Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
ce620ce6
Commit
ce620ce6
authored
Apr 25, 2019
by
Pavel KORSHUNOV
Browse files
tfrecords: added fixed batch shuffling and allow lists of tfrecords
parent
a11d5d41
Pipeline
#29590
failed with stage
in 164 minutes and 41 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/dataset/tfrecords.py
View file @
ce620ce6
...
...
@@ -297,7 +297,7 @@ def create_dataset_from_records_with_augmentation(
if
feature
is
None
:
feature
=
DEFAULT_FEATURE
if
os
.
path
.
isdir
(
tfrecord_filenames
):
if
isinstance
(
tfrecord_filenames
,
str
)
and
os
.
path
.
isdir
(
tfrecord_filenames
):
tfrecord_filenames
=
[
os
.
path
.
join
(
tfrecord_filenames
,
f
)
for
f
in
os
.
listdir
(
tfrecord_filenames
)
]
...
...
@@ -339,6 +339,7 @@ def shuffle_data_and_labels_image_augmentation(
per_image_normalization
=
True
,
random_gamma
=
False
,
random_crop
=
False
,
fixed_batch_size
=
False
,
):
"""
Dump random batches from a list of tf-record files and applies some image augmentation
...
...
@@ -384,9 +385,12 @@ def shuffle_data_and_labels_image_augmentation(
random_rotate:
Randomly rotate face images between -5 and 5 degrees
per_image_normalization:
Linearly scales image to have zero mean and unit norm.
fixed_batch_size:
If True, the last remaining batch that has smaller size than `batch_size` will be dropped.
"""
dataset
=
create_dataset_from_records_with_augmentation
(
...
...
@@ -405,7 +409,12 @@ def shuffle_data_and_labels_image_augmentation(
random_crop
=
random_crop
,
)
dataset
=
dataset
.
shuffle
(
buffer_size
).
batch
(
batch_size
).
repeat
(
epochs
)
dataset
=
dataset
.
shuffle
(
buffer_size
)
if
fixed_batch_size
:
dataset
=
dataset
.
apply
(
tf
.
contrib
.
data
.
batch_and_drop_remainder
(
batch_size
))
else
:
dataset
=
dataset
.
batch
(
batch_size
)
dataset
=
dataset
.
repeat
(
epochs
)
dataset
=
dataset
.
map
(
lambda
d
,
l
,
k
:
({
"data"
:
d
,
"key"
:
k
},
l
))
...
...
@@ -507,6 +516,7 @@ def batch_data_and_labels_image_augmentation(
per_image_normalization
=
True
,
random_gamma
=
False
,
random_crop
=
False
,
fixed_batch_size
=
False
,
):
"""
Dump in order batches from a list of tf-record files
...
...
@@ -528,6 +538,9 @@ def batch_data_and_labels_image_augmentation(
epochs:
Number of epochs to be batched
fixed_batch_size:
If True, the last remaining batch that has smaller size than `batch_size` will be dropped.
"""
dataset
=
create_dataset_from_records_with_augmentation
(
...
...
@@ -546,7 +559,11 @@ def batch_data_and_labels_image_augmentation(
random_crop
=
random_crop
,
)
dataset
=
dataset
.
batch
(
batch_size
).
repeat
(
epochs
)
if
fixed_batch_size
:
dataset
=
dataset
.
apply
(
tf
.
contrib
.
data
.
batch_and_drop_remainder
(
batch_size
))
else
:
dataset
=
dataset
.
batch
(
batch_size
)
dataset
=
dataset
.
repeat
(
epochs
)
data
,
labels
,
key
=
dataset
.
make_one_shot_iterator
().
get_next
()
features
=
dict
()
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment