Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
beat
beat.core
Commits
0403717b
Commit
0403717b
authored
Nov 27, 2017
by
Philip ABBET
Browse files
[unittests] Refactoring of the 'test_cacheddata.py' file
parent
a8871304
Changes
1
Hide whitespace changes
Inline
Side-by-side
beat/core/test/test_cacheddata.py
100644 → 100755
View file @
0403717b
...
...
@@ -3,7 +3,7 @@
###############################################################################
# #
# Copyright (c) 201
6
Idiap Research Institute, http://www.idiap.ch/ #
# Copyright (c) 201
7
Idiap Research Institute, http://www.idiap.ch/ #
# Contact: beat.support@idiap.ch #
# #
# This file is part of the beat.core module of the BEAT platform. #
...
...
@@ -26,234 +26,170 @@
###############################################################################
import
unittest
import
os
import
glob
import
tempfile
import
six
import
numpy
import
nose.tools
from
..data
import
CachedDataSink
,
CachedDataSource
,
foundSplitRanges
from
..hash
import
hashFileContents
from
..dataformat
import
DataFormat
from
.
import
prefix
testfile
=
None
def
create_tempfile
():
global
testfile
testfile
=
tempfile
.
NamedTemporaryFile
(
prefix
=
__name__
,
suffix
=
'.data'
)
testfile
.
close
()
#preserve only name
#----------------------------------------------------------
class
TestCachedDataBase
(
unittest
.
TestCase
):
def
erase_tempfiles
():
global
testfile
basename
,
data_ext
=
os
.
path
.
splitext
(
testfile
.
name
)
filenames
=
[
testfile
.
name
]
filenames
+=
glob
.
glob
(
basename
+
'*'
+
data_ext
)
filenames
+=
glob
.
glob
(
basename
+
'*'
+
data_ext
+
'.checksum'
)
filenames
+=
glob
.
glob
(
basename
+
'*.index'
)
filenames
+=
glob
.
glob
(
basename
+
'*.index.checksum'
)
for
filename
in
filenames
:
if
os
.
path
.
exists
(
filename
):
os
.
unlink
(
filename
)
def
setUp
(
self
):
testfile
=
tempfile
.
NamedTemporaryFile
(
prefix
=
__name__
,
suffix
=
'.data'
)
testfile
.
close
()
# preserve only the name
self
.
filename
=
testfile
.
name
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_data_sink_creation
():
dataformat
=
DataFormat
(
prefix
,
'user/integers/1'
)
assert
dataformat
.
valid
def
tearDown
(
self
):
basename
,
ext
=
os
.
path
.
splitext
(
self
.
filename
)
filenames
=
[
self
.
filename
]
filenames
+=
glob
.
glob
(
basename
+
'*'
)
for
filename
in
filenames
:
if
os
.
path
.
exists
(
filename
):
os
.
unlink
(
filename
)
data_sink
=
CachedDataSink
()
assert
data_sink
.
setup
(
testfile
.
name
,
dataformat
)
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_data_source_creation
():
def
writeData
(
self
,
dataformat_name
,
start_index
=
0
,
end_index
=
10
):
dataformat
=
DataFormat
(
prefix
,
dataformat_name
)
self
.
assertTrue
(
dataformat
.
valid
)
f
=
open
(
testfile
.
name
,
'wb'
)
f
.
write
(
b
'json
\n
user/integers/1
\n
'
)
f
.
close
()
data_sink
=
CachedDataSink
()
self
.
assertTrue
(
data_sink
.
setup
(
self
.
filename
,
dataformat
))
chksum_data
=
hashFileContents
(
testfile
.
name
)
all_data
=
[]
for
i
in
range
(
start_index
,
end_index
+
1
):
data
=
dataformat
.
type
()
data_sink
.
write
(
data
,
i
,
i
)
all_data
.
append
(
data
)
with
open
(
testfile
.
name
+
'.checksum'
,
'wt'
)
as
f
:
f
.
write
(
chksum_data
)
(
nb_bytes
,
duration
)
=
data_sink
.
statistics
()
self
.
assertTrue
(
nb_bytes
>
0
)
self
.
assertTrue
(
duration
>
0
)
data_source
=
CachedDataSource
()
data_sink
.
close
()
del
data_sink
assert
data_source
.
setup
(
testfile
.
name
,
prefix
)
assert
data_source
.
dataformat
.
valid
assert
not
data_source
.
hasMoreData
()
return
all_data
(
data
,
start_index
,
end_index
)
=
data_source
.
next
()
assert
data
is
None
assert
start_index
is
None
assert
end_index
is
None
data_source
.
close
()
#----------------------------------------------------------
def
test_cached_data_split
():
l
=
[[
0
,
2
,
4
,
6
,
8
,
10
,
12
],[
0
,
3
,
6
,
9
,
12
]]
n_split
=
2
ref
=
[(
0
,
5
),
(
6
,
11
)]
res
=
foundSplitRanges
(
l
,
n_split
)
nose
.
tools
.
eq_
(
res
,
ref
)
class
TestDataSink
(
TestCachedDataBase
):
l
=
[[
0
,
2
,
4
,
6
,
8
,
10
,
12
,
15
],[
0
,
3
,
6
,
9
,
12
,
15
]]
n_split
=
5
ref
=
[(
0
,
5
),
(
6
,
11
),
(
12
,
14
)]
res
=
foundSplitRanges
(
l
,
n_split
)
nose
.
tools
.
eq_
(
res
,
ref
)
def
test_creation
(
self
):
dataformat
=
DataFormat
(
prefix
,
'user/integers/1'
)
self
.
assertTrue
(
dataformat
.
valid
)
def
serialization
(
format_name
,
data_modifier
=
None
,
data_tester
=
None
):
data_sink
=
CachedDataSink
()
self
.
assertTrue
(
data_sink
.
setup
(
self
.
filename
,
dataformat
))
dataformat
=
DataFormat
(
prefix
,
format_name
)
assert
dataformat
.
valid
data_sink
=
CachedDataSink
()
assert
data_sink
.
setup
(
testfile
.
name
,
dataformat
)
#----------------------------------------------------------
for
i
in
six
.
moves
.
range
(
0
,
5
):
data
=
dataformat
.
type
()
if
data_modifier
is
not
None
:
data_modifier
(
data
)
data_sink
.
write
(
data
,
i
,
i
)
(
nb_bytes
,
duration
)
=
data_sink
.
statistics
()
assert
nb_bytes
>
0
assert
duration
>
0
class
TestDataSource
(
TestCachedDataBase
):
d
ata_sink
.
close
()
del
data_sink
d
ef
test_creation
(
self
):
self
.
writeData
(
'user/integers/1'
)
data_source
=
CachedDataSource
()
assert
data_source
.
setup
(
testfile
.
name
,
prefix
)
data_source
=
CachedDataSource
()
for
i
in
six
.
moves
.
range
(
0
,
5
):
assert
data_source
.
hasMoreData
()
self
.
assertTrue
(
data_source
.
setup
(
self
.
filename
,
prefix
))
self
.
assertTrue
(
data_source
.
dataformat
.
valid
)
self
.
assertTrue
(
data_source
.
hasMoreData
())
(
data
,
start_index
,
end_index
)
=
data_source
.
next
()
assert
data
is
not
None
nose
.
tools
.
eq_
(
start_index
,
i
)
nose
.
tools
.
eq_
(
end_index
,
i
)
data_source
.
close
()
if
data_tester
is
not
None
:
data_tester
(
data
)
assert
not
data_source
.
hasMoreData
()
def
perform_deserialization
(
self
,
dataformat_name
,
start_index
=
0
,
end_index
=
10
):
reference
=
self
.
writeData
(
dataformat_name
)
# Always generate 10 data units
(
nb_bytes
,
duration
)
=
data_source
.
statistics
()
assert
nb_bytes
>
0
assert
duration
>
0
data_source
=
CachedDataSource
()
data_source
.
close
()
self
.
assertTrue
(
data_source
.
setup
(
self
.
filename
,
prefix
,
force_start_index
=
start_index
,
force_end_index
=
end_index
))
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_integers
():
self
.
assertTrue
(
data_source
.
dataformat
.
valid
)
def
data_modifier
(
data
):
data
.
value8
=
numpy
.
int8
(
2
**
6
)
data
.
value16
=
numpy
.
int16
(
2
**
14
)
data
.
value32
=
numpy
.
int32
(
2
**
30
)
data
.
value64
=
numpy
.
int64
(
2
**
62
)
for
i
in
range
(
start_index
,
end_index
+
1
):
self
.
assertTrue
(
data_source
.
hasMoreData
())
def
data_tester
(
data
):
nose
.
tools
.
eq_
(
data
.
value8
,
numpy
.
int8
(
2
**
6
))
nose
.
tools
.
eq_
(
data
.
value16
,
numpy
.
int16
(
2
**
14
))
nose
.
tools
.
eq_
(
data
.
value32
,
numpy
.
int32
(
2
**
30
))
nose
.
tools
.
eq_
(
data
.
value64
,
numpy
.
int64
(
2
**
62
))
(
data
,
start
,
end
)
=
data_source
.
next
()
self
.
assertTrue
(
data
is
not
None
)
serialization
(
'user/integers/1'
,
data_modifier
=
data_modifier
,
data_tester
=
data_tester
)
self
.
assertEqual
(
i
,
start
)
self
.
assertEqual
(
i
,
end
)
self
.
assertEqual
(
reference
[
i
].
as_dict
(),
data
.
as_dict
())
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_objects
():
serialization
(
'user/two_objects/1'
)
self
.
assertFalse
(
data_source
.
hasMoreData
())
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_hierarchy_of_objects
():
se
rialization
(
'user/hierarchy_of_objects/1'
)
(
nb_bytes
,
duration
)
=
data_source
.
statistics
(
)
self
.
assertTrue
(
nb_bytes
>
0
)
se
lf
.
assertTrue
(
duration
>
0
)
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_3d_array_of_integers
():
serialization
(
'user/3d_array_of_integers/1'
)
data_source
.
close
()
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_3d_array_of_objects
():
serialization
(
'user/3d_array_of_objects/1'
)
def
splitting
(
format_name
,
data_modifier
=
None
,
data_tester
=
None
):
def
test_integers
(
self
):
self
.
perform_deserialization
(
'user/integers/1'
)
dataformat
=
DataFormat
(
prefix
,
format_name
)
assert
dataformat
.
valid
d
ata_sink
=
CachedDataSink
()
assert
data_sink
.
setup
(
testfile
.
name
,
dataformat
)
d
ef
test_objects
(
self
):
self
.
perform_deserialization
(
'user/two_objects/1'
)
for
i
in
six
.
moves
.
range
(
0
,
10
):
data
=
dataformat
.
type
()
if
data_modifier
is
not
None
:
data_modifier
(
data
)
data_sink
.
write
(
data
,
i
,
i
)
(
nb_bytes
,
duration
)
=
data_sink
.
statistics
()
assert
nb_bytes
>
0
assert
duration
>
0
def
test_hierarchy_of_objects
(
self
):
self
.
perform_deserialization
(
'user/hierarchy_of_objects/1'
)
data_sink
.
close
()
force_start_index
=
3
force_end_index
=
8
data_source
=
CachedDataSource
()
assert
data_source
.
setup
(
testfile
.
name
,
prefix
,
force_start_index
=
force_start_index
,
force_end_index
=
force_end_index
)
def
test_3d_array_of_integers
(
self
):
self
.
perform_deserialization
(
'user/3d_array_of_integers/1'
)
for
i
in
six
.
moves
.
range
(
force_start_index
,
force_end_index
+
1
):
assert
data_source
.
hasMoreData
()
(
data
,
start_index
,
end_index
)
=
data_source
.
next
()
assert
data
is
not
None
nose
.
tools
.
eq_
(
start_index
,
i
)
nose
.
tools
.
eq_
(
end_index
,
i
)
def
test_3d_array_of_objects
(
self
):
self
.
perform_deserialization
(
'user/3d_array_of_objects/1'
)
if
data_tester
is
not
None
:
data_tester
(
data
)
assert
not
data_source
.
hasMoreData
()
def
test_integers_slice_1
(
self
):
self
.
perform_deserialization
(
'user/integers/1'
,
0
,
4
)
(
nb_bytes
,
duration
)
=
data_source
.
statistics
()
assert
nb_bytes
>
0
assert
duration
>
0
data_source
.
close
()
def
test_integers_slice_2
(
self
):
self
.
perform_deserialization
(
'user/integers/1'
,
3
,
6
)
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_integers_splitting
():
def
test_integers_slice_3
(
self
):
self
.
perform_deserialization
(
'user/integers/1'
,
7
,
10
)
def
data_modifier
(
data
):
data
.
value8
=
numpy
.
int8
(
2
**
6
)
data
.
value16
=
numpy
.
int16
(
2
**
14
)
data
.
value32
=
numpy
.
int32
(
2
**
30
)
data
.
value64
=
numpy
.
int64
(
2
**
62
)
def
data_tester
(
data
):
nose
.
tools
.
eq_
(
data
.
value8
,
numpy
.
int8
(
2
**
6
))
nose
.
tools
.
eq_
(
data
.
value16
,
numpy
.
int16
(
2
**
14
))
nose
.
tools
.
eq_
(
data
.
value32
,
numpy
.
int32
(
2
**
30
))
nose
.
tools
.
eq_
(
data
.
value64
,
numpy
.
int64
(
2
**
62
))
#----------------------------------------------------------
splitting
(
'user/integers/1'
,
data_modifier
=
data_modifier
,
data_tester
=
data_tester
)
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_objects_splitting
():
splitting
(
'user/two_objects/1'
)
class
TestFoundSplitRanges
(
unittest
.
TestCase
):
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_hierarchy_of_objects_splitting
():
splitting
(
'user/hierarchy_of_objects/1'
)
def
test_2_splits
(
self
):
l
=
[[
0
,
2
,
4
,
6
,
8
,
10
,
12
],
[
0
,
3
,
6
,
9
,
12
]]
n_split
=
2
ref
=
[(
0
,
5
),
(
6
,
11
)]
res
=
foundSplitRanges
(
l
,
n_split
)
self
.
assertEqual
(
res
,
ref
)
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_3d_array_of_integers_splitting
():
splitting
(
'user/3d_array_of_integers/1'
)
@
nose
.
tools
.
with_setup
(
create_tempfile
,
erase_tempfiles
)
def
test_3d_array_of_objects_splitting
():
splitting
(
'user/3d_array_of_objects/1'
)
def
test_5_splits
(
self
):
l
=
[[
0
,
2
,
4
,
6
,
8
,
10
,
12
,
15
],
[
0
,
3
,
6
,
9
,
12
,
15
]]
n_split
=
5
ref
=
[(
0
,
5
),
(
6
,
11
),
(
12
,
14
)]
res
=
foundSplitRanges
(
l
,
n_split
)
self
.
assertEqual
(
res
,
ref
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment