Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.ip.tensorflow_extractor
Commits
ff821c8a
Commit
ff821c8a
authored
Jun 07, 2019
by
Amir MOHAMMADI
Browse files
Merge branch 'facenet' into 'master'
Improve graph and session handling in facenet class See merge request
!13
parents
2aa5ca35
a34446f5
Pipeline
#30799
passed with stages
in 17 minutes and 48 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
bob/ip/tensorflow_extractor/FaceNet.py
View file @
ff821c8a
...
...
@@ -115,23 +115,24 @@ class FaceNet(object):
# code from https://github.com/davidsandberg/facenet
model_exp
=
os
.
path
.
expanduser
(
self
.
model_path
)
if
(
os
.
path
.
isfile
(
model_exp
)):
logger
.
info
(
'Model filename: %s'
%
model_exp
)
with
tf
.
gfile
.
FastGFile
(
model_exp
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
else
:
logger
.
info
(
'Model directory: %s'
%
model_exp
)
meta_file
,
ckpt_file
=
get_model_filenames
(
model_exp
)
logger
.
info
(
'Metagraph file: %s'
%
meta_file
)
logger
.
info
(
'Checkpoint file: %s'
%
ckpt_file
)
saver
=
tf
.
train
.
import_meta_graph
(
os
.
path
.
join
(
model_exp
,
meta_file
))
saver
.
restore
(
tf
.
get_default_session
(),
os
.
path
.
join
(
model_exp
,
ckpt_file
))
with
self
.
graph
.
as_default
():
if
(
os
.
path
.
isfile
(
model_exp
)):
logger
.
info
(
'Model filename: %s'
%
model_exp
)
with
tf
.
gfile
.
FastGFile
(
model_exp
,
'rb'
)
as
f
:
graph_def
=
tf
.
GraphDef
()
graph_def
.
ParseFromString
(
f
.
read
())
tf
.
import_graph_def
(
graph_def
,
name
=
''
)
else
:
logger
.
info
(
'Model directory: %s'
%
model_exp
)
meta_file
,
ckpt_file
=
get_model_filenames
(
model_exp
)
logger
.
info
(
'Metagraph file: %s'
%
meta_file
)
logger
.
info
(
'Checkpoint file: %s'
%
ckpt_file
)
saver
=
tf
.
train
.
import_meta_graph
(
os
.
path
.
join
(
model_exp
,
meta_file
))
saver
.
restore
(
self
.
session
,
os
.
path
.
join
(
model_exp
,
ckpt_file
))
# Get input and output tensors
self
.
images_placeholder
=
self
.
graph
.
get_tensor_by_name
(
"input:0"
)
self
.
embeddings
=
self
.
graph
.
get_tensor_by_name
(
self
.
layer_name
)
...
...
@@ -142,8 +143,8 @@ class FaceNet(object):
def
__call__
(
self
,
img
):
images
=
self
.
_check_feature
(
img
)
if
self
.
session
is
None
:
self
.
session
=
tf
.
InteractiveSession
()
self
.
graph
=
tf
.
get_default_
graph
(
)
self
.
graph
=
tf
.
Graph
()
self
.
session
=
tf
.
Session
(
graph
=
self
.
graph
)
if
self
.
embeddings
is
None
:
self
.
load_model
()
feed_dict
=
{
self
.
images_placeholder
:
images
,
...
...
@@ -152,9 +153,6 @@ class FaceNet(object):
self
.
embeddings
,
feed_dict
=
feed_dict
)
return
features
.
flatten
()
def
__del__
(
self
):
tf
.
reset_default_graph
()
@
staticmethod
def
get_rcvariable
():
"""
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment