Skip to content
Snippets Groups Projects
Commit 19f0e884 authored by Olegs NIKISINS's avatar Olegs NIKISINS
Browse files

Renamed the class MultiNetPatchClassifier to MultiNetPatchExtractor

parent 4886bbad
No related branches found
No related tags found
1 merge request!3Added MultiNetPatchClassifier extractor and utils, temp fix in LightCNN9
...@@ -18,7 +18,7 @@ from bob.ip.pytorch_extractor.utils import net_forward ...@@ -18,7 +18,7 @@ from bob.ip.pytorch_extractor.utils import net_forward
# ============================================================================= # =============================================================================
# Main body: # Main body:
class MultiNetPatchClassifier(Extractor, object): class MultiNetPatchExtractor(Extractor, object):
""" """
This class is designed to pass a set of patches through a possibly multiple This class is designed to pass a set of patches through a possibly multiple
networks and compute a feature vector combining outputs of all networks. networks and compute a feature vector combining outputs of all networks.
...@@ -88,7 +88,7 @@ class MultiNetPatchClassifier(Extractor, object): ...@@ -88,7 +88,7 @@ class MultiNetPatchClassifier(Extractor, object):
Init method. Init method.
""" """
super(MultiNetPatchClassifier, self).__init__(config_file = config_file, super(MultiNetPatchExtractor, self).__init__(config_file = config_file,
config_group = config_group, config_group = config_group,
model_file = model_file, model_file = model_file,
function_name = function_name, function_name = function_name,
......
from .CNN8 import CNN8Extractor from .CNN8 import CNN8Extractor
from .CasiaNet import CasiaNetExtractor from .CasiaNet import CasiaNetExtractor
from .LightCNN9 import LightCNN9Extractor from .LightCNN9 import LightCNN9Extractor
from .MultiNetPatchClassifier import MultiNetPatchClassifier from .MultiNetPatchExtractor import MultiNetPatchExtractor
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
def __appropriate__(*args): def __appropriate__(*args):
...@@ -22,7 +22,7 @@ __appropriate__( ...@@ -22,7 +22,7 @@ __appropriate__(
CNN8Extractor, CNN8Extractor,
CasiaNetExtractor, CasiaNetExtractor,
LightCNN9Extractor, LightCNN9Extractor,
MultiNetPatchClassifier, MultiNetPatchExtractor,
) )
# gets sphinx autodoc done right - don't remove it # gets sphinx autodoc done right - don't remove it
......
...@@ -46,12 +46,12 @@ def test_lightcnn9(): ...@@ -46,12 +46,12 @@ def test_lightcnn9():
output = extractor(data) output = extractor(data)
assert output.shape[0] == 256 assert output.shape[0] == 256
def test_multi_net_patch_classifier(): def test_multi_net_patch_extractor():
""" """
Test the MultiNetPatchClassifier extractor class. Test the MultiNetPatchExtractor extractor class.
""" """
from bob.ip.pytorch_extractor import MultiNetPatchClassifier from bob.ip.pytorch_extractor import MultiNetPatchExtractor
# ========================================================================= # =========================================================================
# prepare the test data: # prepare the test data:
...@@ -91,13 +91,13 @@ def test_multi_net_patch_classifier(): ...@@ -91,13 +91,13 @@ def test_multi_net_patch_classifier():
PATCH_RESHAPE_PARAMETERS = [3, 128, 128] # reshape vectorized patches to this dimensions before passing to the Network PATCH_RESHAPE_PARAMETERS = [3, 128, 128] # reshape vectorized patches to this dimensions before passing to the Network
image_extractor = MultiNetPatchClassifier(config_file = CONFIG_FILE, image_extractor = MultiNetPatchExtractor(config_file = CONFIG_FILE,
config_group = CONFIG_GROUP, config_group = CONFIG_GROUP,
model_file = MODEL_FILE, model_file = MODEL_FILE,
function_name = FUNCTION_NAME, function_name = FUNCTION_NAME,
function_kwargs = FUNCTION_KWARGS, function_kwargs = FUNCTION_KWARGS,
patches_num = PATCHES_NUM, patches_num = PATCHES_NUM,
patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS) patch_reshape_parameters = PATCH_RESHAPE_PARAMETERS)
# pass through encoder only, compute latent vector: # pass through encoder only, compute latent vector:
latent_vector = image_extractor(patch_flat) latent_vector = image_extractor(patch_flat)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment