bob / bob.learn.tensorflow

Commit bc57c932
authored Sep 09, 2016 by Tiago de Freitas Pereira
Implemented the prefetch for the siamese trainer
parent 86f2c9cf

Changes 13
bob/learn/tensorflow/analyzers/Analizer.py
@@ -45,14 +45,14 @@ class Analizer:
     def __call__(self):

         # Extracting features for enrollment
-        enroll_data, enroll_labels = self.data_shuffler.get_batch(train_dataset=False)
+        enroll_data, enroll_labels = self.data_shuffler.get_batch()
         enroll_features = self.machine(enroll_data, session=self.session)
         del enroll_data

         #import ipdb; ipdb.set_trace();

         # Extracting features for probing
-        probe_data, probe_labels = self.data_shuffler.get_batch(train_dataset=False)
+        probe_data, probe_labels = self.data_shuffler.get_batch()
         probe_features = self.machine(probe_data, session=self.session)
         del probe_data
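With the train/validation switch gone from get_batch, the analyzer is expected to be handed a shuffler that was built over the evaluation data itself. A minimal sketch of that setup, assuming the MemoryDataShuffler constructor shown further down in this commit (the arrays and sizes are invented for illustration):

    import numpy
    from bob.learn.tensorflow.data.MemoryDataShuffler import MemoryDataShuffler

    # Hypothetical evaluation set: 300 grayscale 28x28 images with integer labels.
    validation_images = numpy.random.rand(300, 28, 28, 1)
    validation_labels = numpy.random.randint(0, 10, size=300)

    # One shuffler per set replaces the old internal train/validation split.
    validation_shuffler = MemoryDataShuffler(validation_images, validation_labels,
                                             input_shape=[28, 28, 1],
                                             batch_size=300)

    # The analyzer now draws its enrollment/probe batches with the bare call.
    enroll_data, enroll_labels = validation_shuffler.get_batch()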
bob/learn/tensorflow/data/BaseDataShuffler.py
@@ -11,10 +11,8 @@ class BaseDataShuffler(object):
     def __init__(self, data, labels,
                  input_shape,
                  input_dtype="float64",
-                 perc_train=0.9,
                  scale=True,
-                 train_batch_size=1,
-                 validation_batch_size=300):
+                 batch_size=1):
         """
          The class provide base functionoalies to shuffle the data
@@ -32,54 +30,36 @@ class BaseDataShuffler(object):
         self.input_dtype = input_dtype

         # TODO: Check if the bacth size is higher than the input data
-        self.train_batch_size = train_batch_size
-        self.validation_batch_size = validation_batch_size
+        self.batch_size = batch_size

         self.data = data
-        self.train_shape = tuple([train_batch_size] + input_shape)
-        self.validation_shape = tuple([validation_batch_size] + input_shape)
+        self.shape = tuple([batch_size] + input_shape)

         # TODO: Check if the labels goes from O to N-1
         self.labels = labels
         self.possible_labels = list(set(self.labels))

         # Computing the data samples fro train and validation
         self.n_samples = len(self.labels)
-        self.n_train_samples = int(round(self.n_samples * perc_train))
-        self.n_validation_samples = self.n_samples - self.n_train_samples

         # Shuffling all the indexes
         self.indexes = numpy.array(range(self.n_samples))
         numpy.random.shuffle(self.indexes)

-        # Spliting the data between train and validation
-        self.train_data = self.data[self.indexes[0:self.n_train_samples], ...]
-        self.train_labels = self.labels[self.indexes[0:self.n_train_samples]]
-
-        self.validation_data = self.data[self.indexes[self.n_train_samples:self.n_train_samples + self.n_validation_samples], ...]
-        self.validation_labels = self.labels[self.indexes[self.n_train_samples:self.n_train_samples + self.n_validation_samples]]

-    def get_placeholders_forprefetch(self, name="", train_dataset=True):
+    def get_placeholders_forprefetch(self, name=""):
         """
         Returns a place holder with the size of your batch
         """
-        shape = self.train_shape if train_dataset else self.validation_shape
-        data = tf.placeholder(tf.float32, shape=tuple([None] + list(shape[1:])), name=name)
+        data = tf.placeholder(tf.float32, shape=tuple([None] + list(self.shape[1:])), name=name)
         labels = tf.placeholder(tf.int64, shape=[None, ])
         return data, labels

-    def get_placeholders(self, name="", train_dataset=True):
+    def get_placeholders(self, name=""):
         """
         Returns a place holder with the size of your batch
         """
-        shape = self.train_shape if train_dataset else self.validation_shape
-        data = tf.placeholder(tf.float32, shape=shape, name=name)
-        labels = tf.placeholder(tf.int64, shape=shape[0])
+        data = tf.placeholder(tf.float32, shape=self.shape, name=name)
+        labels = tf.placeholder(tf.int64, shape=self.shape[0])
         return data, labels
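The prefetch placeholders returned above keep the batch dimension open (None), which is what a queue-based input pipeline needs. The queue itself is not part of this commit; the following is only a sketch of how a siamese trainer could prefetch pairs with these placeholders using the TensorFlow 0.x queue API (the shuffler and session objects, and the feeding thread, are assumptions for illustration):

    import threading
    import tensorflow as tf

    # Assumed to exist already: `shuffler`, a data shuffler built as in this
    # commit, and `session`, an open tf.Session.
    left_ph, labels_ph = shuffler.get_placeholders_forprefetch(name="left")
    right_ph, _ = shuffler.get_placeholders_forprefetch(name="right")

    # FIFO queue holding ready-made siamese pairs; the trainer dequeues inside the graph.
    queue = tf.FIFOQueue(capacity=10,
                         dtypes=[tf.float32, tf.float32, tf.int64],
                         shapes=[shuffler.shape[1:], shuffler.shape[1:], []])
    enqueue_op = queue.enqueue_many([left_ph, right_ph, labels_ph])
    left, right, pair_labels = queue.dequeue_many(shuffler.batch_size)

    def feed_queue(session):
        # Background thread keeps the queue filled with genuine/impostor pairs.
        while True:
            l, r, y = shuffler.get_pair()
            session.run(enqueue_op, feed_dict={left_ph: l, right_ph: r, labels_ph: y})

    feeder = threading.Thread(target=feed_queue, args=(session,))
    feeder.daemon = True
    feeder.start()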
bob/learn/tensorflow/data/MemoryDataShuffler.py
@@ -20,10 +20,8 @@ class MemoryDataShuffler(BaseDataShuffler):
     def __init__(self, data, labels,
                  input_shape,
                  input_dtype="float64",
-                 perc_train=0.9,
                  scale=True,
-                 train_batch_size=1,
-                 validation_batch_size=300):
+                 batch_size=1):
         """
          Shuffler that deal with memory datasets
@@ -41,36 +39,22 @@ class MemoryDataShuffler(BaseDataShuffler):
                                                  labels=labels,
                                                  input_shape=input_shape,
                                                  input_dtype=input_dtype,
-                                                 perc_train=perc_train,
                                                  scale=scale,
-                                                 train_batch_size=train_batch_size,
-                                                 validation_batch_size=validation_batch_size
+                                                 batch_size=batch_size
                                                  )

-        self.train_data = self.train_data.astype(input_dtype)
-        self.validation_data = self.validation_data.astype(input_dtype)
+        self.data = self.data.astype(input_dtype)

         if self.scale:
-            self.train_data *= self.scale_value
-            self.validation_data *= self.scale_value
+            self.data *= self.scale_value

-    def get_batch(self, train_dataset=True):
-
-        if train_dataset:
-            n_samples = self.train_batch_size
-            data = self.train_data
-            label = self.train_labels
-        else:
-            n_samples = self.validation_batch_size
-            data = self.validation_data
-            label = self.validation_labels
+    def get_batch(self):

         # Shuffling samples
-        indexes = numpy.array(range(data.shape[0]))
+        indexes = numpy.array(range(self.data.shape[0]))
         numpy.random.shuffle(indexes)

-        selected_data = data[indexes[0:n_samples], :, :, :]
-        selected_labels = label[indexes[0:n_samples]]
+        selected_data = self.data[indexes[0:self.batch_size], :, :, :]
+        selected_labels = self.labels[indexes[0:self.batch_size]]

         return selected_data.astype("float32"), selected_labels
@@ -83,23 +67,13 @@ class MemoryDataShuffler(BaseDataShuffler):
             **Return**
         """

-        if train_dataset:
-            target_data = self.train_data
-            target_labels = self.train_labels
-            shape = self.train_shape
-        else:
-            target_data = self.validation_data
-            target_labels = self.validation_labels
-            shape = self.validation_shape
-
-        data = numpy.zeros(shape=shape, dtype='float32')
-        data_p = numpy.zeros(shape=shape, dtype='float32')
-        labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')
+        data = numpy.zeros(shape=self.shape, dtype='float32')
+        data_p = numpy.zeros(shape=self.shape, dtype='float32')
+        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')

         genuine = True
-        for i in range(shape[0]):
-            data[i, ...], data_p[i, ...] = self.get_genuine_or_not(target_data, target_labels, genuine=genuine)
+        for i in range(self.shape[0]):
+            data[i, ...], data_p[i, ...] = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
             if zero_one_labels:
                 labels_siamese[i] = not genuine
             else:
@@ -107,3 +81,62 @@ class MemoryDataShuffler(BaseDataShuffler):
             genuine = not genuine

         return data, data_p, labels_siamese
+
+    def get_triplet(self, n_labels, n_triplets=1, is_target_set_train=True):
+        """
+        Get a triplet
+
+        **Parameters**
+            is_target_set_train: Defining the target set to get the batch
+
+        **Return**
+        """
+
+        def get_one_triplet(input_data, input_labels):
+            # Getting a pair of clients
+            index = numpy.random.choice(n_labels, 2, replace=False)
+            label_positive = index[0]
+            label_negative = index[1]
+
+            # Getting the indexes of the data from a particular client
+            indexes = numpy.where(input_labels == index[0])[0]
+            numpy.random.shuffle(indexes)
+
+            # Picking a positive pair
+            data_anchor = input_data[indexes[0], :, :, :]
+            data_positive = input_data[indexes[1], :, :, :]
+
+            # Picking a negative sample
+            indexes = numpy.where(input_labels == index[1])[0]
+            numpy.random.shuffle(indexes)
+            data_negative = input_data[indexes[0], :, :, :]
+
+            return data_anchor, data_positive, data_negative, label_positive, label_positive, label_negative
+
+        if is_target_set_train:
+            target_data = self.train_data
+            target_labels = self.train_labels
+        else:
+            target_data = self.validation_data
+            target_labels = self.validation_labels
+
+        c = target_data.shape[3]
+        w = target_data.shape[1]
+        h = target_data.shape[2]
+
+        data_a = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
+        data_p = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
+        data_n = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
+        labels_a = numpy.zeros(shape=n_triplets, dtype='float32')
+        labels_p = numpy.zeros(shape=n_triplets, dtype='float32')
+        labels_n = numpy.zeros(shape=n_triplets, dtype='float32')
+
+        for i in range(n_triplets):
+            data_a[i, :, :, :], data_p[i, :, :, :], data_n[i, :, :, :], \
+            labels_a[i], labels_p[i], labels_n[i] = \
+                get_one_triplet(target_data, target_labels)
+
+        return data_a, data_p, data_n, labels_a, labels_p, labels_n
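A small usage sketch for the refactored class, assuming the constructor keeps the keyword arguments visible in the hunks above (data arrays and sizes are made up for illustration):

    import numpy
    from bob.learn.tensorflow.data.MemoryDataShuffler import MemoryDataShuffler

    # Illustrative in-memory dataset: 100 grayscale 28x28 images, 10 classes.
    images = numpy.random.rand(100, 28, 28, 1)
    labels = numpy.random.randint(0, 10, size=100)

    shuffler = MemoryDataShuffler(images, labels,
                                  input_shape=[28, 28, 1],
                                  batch_size=16)

    # Plain shuffled batch of batch_size samples.
    batch_data, batch_labels = shuffler.get_batch()

    # Siamese pairs: with zero_one_labels=True the label is 0 for a genuine pair
    # and 1 for an impostor pair, following the loop in the hunk above.
    left, right, pair_labels = shuffler.get_pair(zero_one_labels=True)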
bob/learn/tensorflow/data/MemoryPairDataShuffler.py
deleted 100644 → 0
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import numpy
from .MemoryDataShuffler import MemoryDataShuffler


class MemoryPairDataShuffler(MemoryDataShuffler):
    def __init__(self, data, labels,
                 input_shape,
                 perc_train=0.9,
                 scale=True,
                 train_batch_size=1,
                 validation_batch_size=300):
        """
         The class provide some functionalities for shuffling data

         **Parameters**
           data:
        """

        data = data
        labels = labels
        input_shape = input_shape
        perc_train = perc_train
        scale = scale
        train_batch_size = train_batch_size
        validation_batch_size = validation_batch_size

        super(MemoryPairDataShuffler, self).__init__(data, labels,
                                                     input_shape=input_shape,
                                                     perc_train=perc_train,
                                                     scale=scale,
                                                     train_batch_size=train_batch_size * 2,
                                                     validation_batch_size=validation_batch_size)

    def get_pair(self, train_dataset=True, zero_one_labels=True):
        """
        Get a random pair of samples

        **Parameters**
            is_target_set_train: Defining the target set to get the batch

        **Return**
        """
        def get_genuine_or_not(input_data, input_labels, genuine=True):

            if genuine:
                # TODO: THIS KEY SELECTION NEEDS TO BE MORE EFFICIENT
                # Getting a client
                index = numpy.random.randint(self.total_labels)

                # Getting the indexes of the data from a particular client
                indexes = numpy.where(input_labels == index)[0]
                numpy.random.shuffle(indexes)

                # Picking a pair
                data = input_data[indexes[0], ...]
                data_p = input_data[indexes[1], ...]

            else:
                # Picking a pair from different clients
                index = numpy.random.choice(self.total_labels, 2, replace=False)

                # Getting the indexes of the two clients
                indexes = numpy.where(input_labels == index[0])[0]
                indexes_p = numpy.where(input_labels == index[1])[0]
                numpy.random.shuffle(indexes)
                numpy.random.shuffle(indexes_p)

                # Picking a pair
                data = input_data[indexes[0], ...]
                data_p = input_data[indexes_p[0], ...]

            return data, data_p

        if train_dataset:
            target_data = self.train_data
            target_labels = self.train_labels
            shape = self.train_shape
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels
            shape = self.validation_shape

        data = numpy.zeros(shape=shape, dtype='float32')
        data_p = numpy.zeros(shape=shape, dtype='float32')
        labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')

        genuine = True
        for i in range(shape[0]):
            data[i, ...], data_p[i, ...] = get_genuine_or_not(target_data, target_labels, genuine=genuine)
            if zero_one_labels:
                labels_siamese[i] = not genuine
            else:
                labels_siamese[i] = -1 if genuine else +1
            genuine = not genuine

        return data, data_p, labels_siamese

    def get_triplet(self, n_labels, n_triplets=1, is_target_set_train=True):
        """
        Get a triplet

        **Parameters**
            is_target_set_train: Defining the target set to get the batch

        **Return**
        """
        def get_one_triplet(input_data, input_labels):
            # Getting a pair of clients
            index = numpy.random.choice(n_labels, 2, replace=False)
            label_positive = index[0]
            label_negative = index[1]

            # Getting the indexes of the data from a particular client
            indexes = numpy.where(input_labels == index[0])[0]
            numpy.random.shuffle(indexes)

            # Picking a positive pair
            data_anchor = input_data[indexes[0], :, :, :]
            data_positive = input_data[indexes[1], :, :, :]

            # Picking a negative sample
            indexes = numpy.where(input_labels == index[1])[0]
            numpy.random.shuffle(indexes)
            data_negative = input_data[indexes[0], :, :, :]

            return data_anchor, data_positive, data_negative, label_positive, label_positive, label_negative

        if is_target_set_train:
            target_data = self.train_data
            target_labels = self.train_labels
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels

        c = target_data.shape[3]
        w = target_data.shape[1]
        h = target_data.shape[2]

        data_a = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
        data_p = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
        data_n = numpy.zeros(shape=(n_triplets, w, h, c), dtype='float32')
        labels_a = numpy.zeros(shape=n_triplets, dtype='float32')
        labels_p = numpy.zeros(shape=n_triplets, dtype='float32')
        labels_n = numpy.zeros(shape=n_triplets, dtype='float32')

        for i in range(n_triplets):
            data_a[i, :, :, :], data_p[i, :, :, :], data_n[i, :, :, :], \
            labels_a[i], labels_p[i], labels_n[i] = \
                get_one_triplet(target_data, target_labels)

        return data_a, data_p, data_n, labels_a, labels_p, labels_n
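The pair and triplet sampling that this deleted class provided now lives on MemoryDataShuffler itself: get_pair is reworked in the hunk above and get_triplet is appended at the end of that file. A hypothetical call site would migrate roughly as follows (both snippets are illustrative, not taken from the commit):

    # Before: a dedicated pair shuffler wrapped the memory shuffler.
    #   pair_shuffler = MemoryPairDataShuffler(images, labels, input_shape=[28, 28, 1])
    #   left, right, pair_labels = pair_shuffler.get_pair(train_dataset=True)
    # After: the memory shuffler samples the pairs directly.
    #   left, right, pair_labels = shuffler.get_pair()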
bob/learn/tensorflow/data/TextDataShuffler.py
@@ -21,10 +21,8 @@ class TextDataShuffler(BaseDataShuffler):
     def __init__(self, data, labels,
                  input_shape,
                  input_dtype="float64",
-                 perc_train=0.9,
                  scale=True,
-                 train_batch_size=1,
-                 validation_batch_size=300):
+                 batch_size=1):
         """
          Shuffler that deal with file list
@@ -48,10 +46,8 @@ class TextDataShuffler(BaseDataShuffler):
                                                labels=labels,
                                                input_shape=input_shape,
                                                input_dtype=input_dtype,
-                                               perc_train=perc_train,
                                                scale=scale,
-                                               train_batch_size=train_batch_size,
-                                               validation_batch_size=validation_batch_size
+                                               batch_size=batch_size
                                                )

     def load_from_file(self, file_name, shape):
@@ -64,38 +60,27 @@ class TextDataShuffler(BaseDataShuffler):
         return data

-    def get_batch(self, train_dataset=True):
-
-        if train_dataset:
-            batch_size = self.train_batch_size
-            shape = self.train_shape
-            files_names = self.train_data
-            label = self.train_labels
-        else:
-            batch_size = self.validation_batch_size
-            shape = self.validation_shape
-            files_names = self.validation_data
-            label = self.validation_labels
+    def get_batch(self):

         # Shuffling samples
-        indexes = numpy.array(range(files_names.shape[0]))
+        indexes = numpy.array(range(self.data.shape[0]))
         numpy.random.shuffle(indexes)

-        selected_data = numpy.zeros(shape=shape)
-        for i in range(batch_size):
+        selected_data = numpy.zeros(shape=self.shape)
+        for i in range(self.batch_size):

-            file_name = files_names[indexes[i]]
-            data = self.load_from_file(file_name, shape)
+            file_name = self.data[indexes[i]]
+            data = self.load_from_file(file_name, self.shape)

             selected_data[i, ...] = data
             if self.scale:
                 selected_data[i, ...] *= self.scale_value

-        selected_labels = label[indexes[0:batch_size]]
+        selected_labels = self.labels[indexes[0:self.batch_size]]

         return selected_data.astype("float32"), selected_labels

-    def get_pair(self, train_dataset=True, zero_one_labels=True):
+    def get_pair(self, zero_one_labels=True):
         """
         Get a random pair of samples
@@ -105,24 +90,15 @@ class TextDataShuffler(BaseDataShuffler):
             **Return**
         """

-        if train_dataset:
-            target_data = self.train_data
-            target_labels = self.train_labels
-            shape = self.train_shape
-        else:
-            target_data = self.validation_data
-            target_labels = self.validation_labels
-            shape = self.validation_shape
-
-        data = numpy.zeros(shape=shape, dtype='float32')
-        data_p = numpy.zeros(shape=shape, dtype='float32')
-        labels_siamese = numpy.zeros(shape=shape[0], dtype='float32')
+        data = numpy.zeros(shape=self.shape, dtype='float32')
+        data_p = numpy.zeros(shape=self.shape, dtype='float32')
+        labels_siamese = numpy.zeros(shape=self.shape[0], dtype='float32')

         genuine = True
-        for i in range(shape[0]):
-            file_name, file_name_p = self.get_genuine_or_not(target_data, target_labels, genuine=genuine)
-            data[i, ...] = self.load_from_file(str(file_name), shape)
-            data_p[i, ...] = self.load_from_file(str(file_name_p), shape)
+        for i in range(self.shape[0]):
+            file_name, file_name_p = self.get_genuine_or_not(self.data, self.labels, genuine=genuine)
+            data[i, ...] = self.load_from_file(str(file_name), self.shape)
+            data_p[i, ...] = self.load_from_file(str(file_name_p), self.shape)

             if zero_one_labels:
                 labels_siamese[i] = not genuine
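The same single-batch API now applies to the file-list shuffler: get_batch shuffles the file names, loads batch_size images from disk and scales them. A sketch assuming the constructor shown above (file names and shapes are invented for illustration):

    import numpy
    from bob.learn.tensorflow.data.TextDataShuffler import TextDataShuffler

    # Illustrative file-list dataset: paths to images on disk plus their labels.
    file_names = numpy.array(['subject1_a.png', 'subject1_b.png',
                              'subject2_a.png', 'subject2_b.png'])
    labels = numpy.array([0, 0, 1, 1])

    shuffler = TextDataShuffler(file_names, labels,
                                input_shape=[28, 28, 1],
                                batch_size=2)

    # Loads and scales `batch_size` images per call.
    batch_data, batch_labels = shuffler.get_batch()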
bob/learn/tensorflow/data/TextPairDataShuffler.py
deleted 100644 → 0
#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Tiago de Freitas Pereira <tiago.pereira@idiap.ch>
# @date: Wed 11 May 2016 09:39:36 CEST

import numpy
from .TextDataShuffler import TextDataShuffler


class TextPairDataShuffler(TextDataShuffler):
    def __init__(self, data, labels,
                 input_shape,
                 perc_train=0.9,
                 scale=True,
                 train_batch_size=1,
                 validation_batch_size=300):
        """
         The class provide some functionalities for shuffling data

         **Parameters**
           data:
        """

        data = data
        labels = labels
        input_shape = input_shape
        perc_train = perc_train
        scale = scale
        train_batch_size = train_batch_size
        validation_batch_size = validation_batch_size

        super(TextPairDataShuffler, self).__init__(data, labels,
                                                   input_shape=input_shape,
                                                   perc_train=perc_train,
                                                   scale=scale,
                                                   train_batch_size=train_batch_size * 2,
                                                   validation_batch_size=validation_batch_size)

    def get_pair(self, train_dataset=True, zero_one_labels=True):
        """
        Get a random pair of samples

        **Parameters**
            is_target_set_train: Defining the target set to get the batch

        **Return**
        """
        def get_genuine_or_not(input_data, input_labels, genuine=True):

            if genuine:
                # TODO: THIS KEY SELECTION NEEDS TO BE MORE EFFICIENT
                # Getting a client
                index = numpy.random.randint(self.total_labels)

                # Getting the indexes of the data from a particular client
                indexes = numpy.where(input_labels == index)[0]
                numpy.random.shuffle(indexes)

                # Picking a pair
                data = input_data[indexes[0]]
                data_p = input_data[indexes[1]]

            else:
                # Picking a pair from different clients
                index = numpy.random.choice(self.total_labels, 2, replace=False)

                # Getting the indexes of the two clients
                indexes = numpy.where(input_labels == index[0])[0]
                indexes_p = numpy.where(input_labels == index[1])[0]
                numpy.random.shuffle(indexes)
                numpy.random.shuffle(indexes_p)

                # Picking a pair
                data = input_data[indexes[0]]
                data_p = input_data[indexes_p[0]]

            return data, data_p

        if train_dataset:
            target_data = self.train_data
            target_labels = self.train_labels
            shape = self.train_shape
        else:
            target_data = self.validation_data
            target_labels = self.validation_labels
            shape = self.validation_shape

        data = numpy.zeros(shape=shape, dtype='float32')
        data_p = numpy.zeros(shape=shape, dtype='float32')