Commit 156262a5 authored by Guillaume HEUSCH

[extractor] fixed DRGANLight extractor

parent 2890af32
@@ -14,7 +14,7 @@ from torch.autograd import Variable
 from bob.learn.pytorch.architectures import DRGAN_encoder as drgan_encoder
 
-class LuanExtractor(Extractor):
+class DRGANLight(Extractor):
     """
     **Parameters:**
@@ -27,7 +27,7 @@ class LuanExtractor(Extractor):
         self.latent_dim = 320
         self.image_size = (3, 64, 64)
-        self.encoder = drgan_encoder(input_image.shape, latent_dim)
+        self.encoder = drgan_encoder(self.image_size, self.latent_dim)
 
         # image pre-processing
         self.to_tensor = transforms.ToTensor()
@@ -53,14 +53,16 @@
         input_image = self.to_tensor(input_image)
         input_image = self.norm(input_image)
         input_image = input_image.unsqueeze(0)
-        encoded_id = encoder.forward(Variable(input_image))
-        print encoded_id
-        import sys
-        sys.exit()
+        encoded_id = self.encoder.forward(Variable(input_image))
+        encoded_id = encoded_id.data.numpy()
+        encoded_id = encoded_id[0, :, 0, 0]
+        #print encoded_id.shape
+        #import sys
+        #sys.exit()
         return encoded_id
 
     # re-define the train function to get it non-documented
     def train(*args, **kwargs): raise NotImplementedError("This function is not implemented and should not be called.")
 
     def load(self, extractor_file):
-        encoder.load_state_dict(torch.load(extractor_file, map_location=lambda storage, loc: storage))
+        self.encoder.load_state_dict(torch.load(extractor_file, map_location=lambda storage, loc: storage))
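
For context, a minimal usage sketch of the extractor after this fix. It assumes the diffed feature-extraction method is the extractor's __call__, that DRGANLight takes no constructor arguments, and that a checkpoint file is available; the checkpoint path and the random stand-in image are hypothetical. Only the 64x64 RGB input size, the 320-dimensional latent output, and the load() signature come from the diff above.

import numpy

# hypothetical usage sketch, not part of the commit
extractor = DRGANLight()                    # assumed default constructor

# load() maps GPU-trained weights onto CPU storage via map_location
extractor.load("drgan_encoder.pth")         # hypothetical checkpoint path

# stand-in for a 64x64 RGB face crop (HWC layout, as transforms.ToTensor expects)
face_crop = numpy.random.rand(64, 64, 3).astype("float32")

# the fixed method normalizes the crop, adds a batch dimension, runs the
# encoder, and indexes the (1, 320, 1, 1) output down to a flat 320-d vector
features = extractor(face_crop)
assert features.shape == (320,)

Note that the map_location=lambda storage, loc: storage argument in load() keeps all tensors on CPU, so a checkpoint trained on a GPU can be loaded on a machine without CUDA.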