Commit da387a3c authored by Tiago de Freitas Pereira's avatar Tiago de Freitas Pereira
Browse files

[py] Created an option that sets GPU

parent 13137aea
......@@ -21,9 +21,10 @@ class ArcFaceInsightFace(TransformerMixin, BaseEstimator):
"""
def __init__(self, **kwargs):
def __init__(self, use_gpu=False, **kwargs):
super().__init__(**kwargs)
self.model = None
self.use_gpu = use_gpu
internal_path = pkg_resources.resource_filename(
__name__, os.path.join("data", "arcface_insightface"),
......@@ -54,8 +55,7 @@ class ArcFaceInsightFace(TransformerMixin, BaseEstimator):
sym = all_layers["fc1_output"]
# LOADING CHECKPOINT
ctx = mx.cpu()
ctx = mx.gpu() if self.use_gpu else mx.cpu()
model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
data_shape = (1, 3, 112, 112)
model.bind(data_shapes=[("data", data_shape)])
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment