Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
bob
bob.learn.tensorflow
Commits
1e21e10c
Commit
1e21e10c
authored
Sep 08, 2020
by
Amir MOHAMMADI
Browse files
remove the extractors folder
parent
a28815cd
Changes
4
Hide whitespace changes
Inline
Side-by-side
bob/learn/tensorflow/extractors/Base.py
deleted
100644 → 0
View file @
a28815cd
import
tensorflow
as
tf
import
os
import
numpy
as
np
import
logging
logger
=
logging
.
getLogger
(
__name__
)
def
normalize_checkpoint_path
(
path
):
if
os
.
path
.
splitext
(
path
)[
1
]
==
".meta"
:
filename
=
os
.
path
.
splitext
(
path
)[
0
]
elif
os
.
path
.
isdir
(
path
):
filename
=
tf
.
train
.
latest_checkpoint
(
path
)
else
:
filename
=
path
return
filename
class
Base
:
def
__init__
(
self
,
output_name
,
input_shape
,
checkpoint
,
scopes
,
input_transform
=
None
,
output_transform
=
None
,
input_dtype
=
'float32'
,
extra_feed
=
None
,
**
kwargs
):
self
.
output_name
=
output_name
self
.
input_shape
=
input_shape
self
.
checkpoint
=
normalize_checkpoint_path
(
checkpoint
)
self
.
scopes
=
scopes
self
.
input_transform
=
input_transform
self
.
output_transform
=
output_transform
self
.
input_dtype
=
input_dtype
self
.
extra_feed
=
extra_feed
self
.
session
=
None
super
().
__init__
(
**
kwargs
)
def
load
(
self
):
self
.
session
=
tf
.
Session
(
graph
=
tf
.
Graph
())
with
self
.
session
.
as_default
(),
self
.
session
.
graph
.
as_default
():
self
.
input
=
data
=
tf
.
placeholder
(
self
.
input_dtype
,
self
.
input_shape
)
if
self
.
input_transform
is
not
None
:
data
=
self
.
input_transform
(
data
)
self
.
output
=
self
.
get_output
(
data
,
tf
.
estimator
.
ModeKeys
.
PREDICT
)
if
self
.
output_transform
is
not
None
:
self
.
output
=
self
.
output_transform
(
self
.
output
)
tf
.
train
.
init_from_checkpoint
(
ckpt_dir_or_file
=
self
.
checkpoint
,
assignment_map
=
self
.
scopes
,
)
# global_variables_initializer must run after init_from_checkpoint
self
.
session
.
run
(
tf
.
global_variables_initializer
())
logger
.
info
(
'Restored the model from %s'
,
self
.
checkpoint
)
def
__call__
(
self
,
data
):
if
self
.
session
is
None
:
self
.
load
()
data
=
np
.
ascontiguousarray
(
data
,
dtype
=
self
.
input_dtype
)
feed_dict
=
{
self
.
input
:
data
}
if
self
.
extra_feed
is
not
None
:
feed_dict
.
update
(
self
.
extra_feed
)
return
self
.
session
.
run
(
self
.
output
,
feed_dict
=
feed_dict
)
def
get_output
(
self
,
data
,
mode
):
raise
NotImplementedError
()
bob/learn/tensorflow/extractors/Estimator.py
deleted
100644 → 0
View file @
a28815cd
import
tensorflow
as
tf
from
.Base
import
Base
class
Estimator
(
Base
):
def
__init__
(
self
,
estimator
,
**
kwargs
):
self
.
estimator
=
estimator
kwargs
[
'checkpoint'
]
=
kwargs
.
get
(
'checkpoint'
,
estimator
.
model_dir
)
super
().
__init__
(
**
kwargs
)
def
get_output
(
self
,
data
,
mode
):
features
=
{
'data'
:
data
,
'key'
:
tf
.
constant
([
'key'
])}
self
.
estimator_spec
=
self
.
estimator
.
_call_model_fn
(
features
,
None
,
mode
,
None
)
self
.
end_points
=
self
.
estimator
.
end_points
return
self
.
end_points
[
self
.
output_name
]
bob/learn/tensorflow/extractors/Generic.py
deleted
100644 → 0
View file @
a28815cd
from
.Base
import
Base
class
Generic
(
Base
):
def
__init__
(
self
,
architecture
,
**
kwargs
):
self
.
architecture
=
architecture
super
().
__init__
(
**
kwargs
)
def
get_output
(
self
,
data
,
mode
):
self
.
end_points
=
self
.
architecture
(
data
,
mode
=
mode
)[
1
]
return
self
.
end_points
[
self
.
output_name
]
bob/learn/tensorflow/extractors/__init__.py
deleted
100644 → 0
View file @
a28815cd
from
.Base
import
Base
,
normalize_checkpoint_path
from
.Generic
import
Generic
from
.Estimator
import
Estimator
# gets sphinx autodoc done right - don't remove it
def
__appropriate__
(
*
args
):
"""Says object was actually declared here, an not on the import module.
Parameters:
*args: An iterable of objects to modify
Resolves `Sphinx referencing issues
<https://github.com/sphinx-doc/sphinx/issues/3048>`
"""
for
obj
in
args
:
obj
.
__module__
=
__name__
__appropriate__
(
Base
,
Generic
,
Estimator
,
)
__all__
=
[
_
for
_
in
dir
()
if
not
_
.
startswith
(
'_'
)]
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment