Commit 6fcf9c5e authored by Olegs NIKISINS

Added unit test for MultiNetPatchClassifier extractor, and pre-trained model for AE

parent ccedc9a3
1 merge request: !3 Added MultiNetPatchClassifier extractor and utils, temp fix in LightCNN9
@@ -45,3 +45,79 @@ def test_lightcnn9():
    data = numpy.random.rand(128, 128).astype("float32")
    output = extractor(data)
    assert output.shape[0] == 256


def test_multi_net_patch_classifier():
    """
    Test the MultiNetPatchClassifier extractor class.
    """
    from bob.ip.pytorch_extractor import MultiNetPatchClassifier

    # =========================================================================
    # prepare the test data:
    patch_2d = numpy.repeat(numpy.expand_dims(numpy.sin(numpy.arange(0, 12.8, 0.1)), axis=0), 128, axis=0)
    patch = numpy.uint8((numpy.stack([patch_2d, patch_2d.transpose(), -patch_2d]) + 1) * 255 / 2.)
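    # note: the resulting test patch is a (3 x 128 x 128) uint8 array, i.e. three
    # sinusoidal "color" channels rescaled to the [0, 255] range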
    # flatten the 3D test patch:
    patch_flat = numpy.expand_dims(patch.flatten(), axis=0)

    # =========================================================================
    # test the extractor:

    # config containing an instance of the Composed Transform and the Network
    # class to be used in the feature extractor:
    CONFIG_FILE = "autoencoder/net1_celeba.py"
    CONFIG_GROUP = "bob.learn.pytorch.config"

    # use a specific/unique model for each patch. Models are pre-trained on
    # CelebA and fine-tuned (3 layers) on BATL:
    MODEL_FILE = [pkg_resources.resource_filename('bob.ip.pytorch_extractor',
                                                   'test_data/conv_ae_model_pretrain_celeba_tune_batl_full_face.pth')]

    FUNCTION_NAME = "net_forward"  # function to be used to extract features given an input patch

    def _prediction_function(local_model, x):  # use only the encoder of the Network loaded from the above config
        x = local_model.encoder(x)
        return x

    # kwargs for the function defined by the FUNCTION_NAME constant:
    FUNCTION_KWARGS = {}
    FUNCTION_KWARGS["config_file"] = CONFIG_FILE
    FUNCTION_KWARGS["config_group"] = CONFIG_GROUP
    FUNCTION_KWARGS["model_file"] = MODEL_FILE
    FUNCTION_KWARGS["invert_scores_flag"] = False
    FUNCTION_KWARGS["prediction_function"] = _prediction_function
    FUNCTION_KWARGS["color_input_flag"] = True

    PATCHES_NUM = [0]  # patches to be used in the feature extraction process
    PATCH_RESHAPE_PARAMETERS = [3, 128, 128]  # reshape vectorized patches to these dimensions before passing them to the Network

    image_extractor = MultiNetPatchClassifier(config_file=CONFIG_FILE,
                                              config_group=CONFIG_GROUP,
                                              model_file=MODEL_FILE,
                                              function_name=FUNCTION_NAME,
                                              function_kwargs=FUNCTION_KWARGS,
                                              patches_num=PATCHES_NUM,
                                              patch_reshape_parameters=PATCH_RESHAPE_PARAMETERS)
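
    # note: the extractor consumes the flattened (vectorized) patch prepared above
    # and reshapes it internally to PATCH_RESHAPE_PARAMETERS before feeding it to
    # the network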

    # pass through the encoder only, compute the latent vector:
    latent_vector = image_extractor(patch_flat)

    # pass through the full AE, compute the reconstructed image:
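    # (setting prediction_function to None presumably makes the extractor fall back
    # to the full net_forward pass, i.e. encoder + decoder, so the output is the
    # reconstructed patch rather than the latent vector)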
    image_extractor.function_kwargs['prediction_function'] = None
    reconstructed = image_extractor(patch_flat).reshape(PATCH_RESHAPE_PARAMETERS)

    # test:
    assert latent_vector.shape == (1296,)
    assert reconstructed.shape == (3, 128, 128)

    # # for visualization/debugging only:
    # import matplotlib.pyplot as plt
    # import bob.io.image
    #
    # plt.figure()
    # plt.imshow(bob.io.image.to_matplotlib(patch))
    # plt.show()
    #
    # plt.figure()
    # plt.imshow(bob.io.image.to_matplotlib(reconstructed))
    # plt.show()
File added: the pre-trained AE model used by the test above (binary file, diff not shown).