Commit 887d9e54 authored by Guillaume HEUSCH's avatar Guillaume HEUSCH

[extractor] modified the code to load the model, to be consistent with MCDeepPixBiS

parent 705fc0c8
Pipeline #35593 passed with stage
in 10 minutes and 36 seconds
...@@ -52,24 +52,17 @@ class DeepPixBiSExtractor(Extractor): ...@@ -52,24 +52,17 @@ class DeepPixBiSExtractor(Extractor):
logger.debug("No pretrained file provided") logger.debug("No pretrained file provided")
pass pass
else: else:
# With the new training
logger.debug('Starting to load the pretrained PAD model') logger.debug('Starting to load the pretrained PAD model')
try: try:
cp = torch.load(model_file) cp = torch.load(model_file)
except: except:
try: raise ValueError('Failed to load the model file : {}'.format(model_file))
cp= torch.load(model_file,map_location=lambda storage,loc:storage)
except:
raise ValueError('Could not load the model')
if 'state_dict' in cp: if 'state_dict' in cp:
self.network.load_state_dict(cp['state_dict']) self.network.load_state_dict(cp['state_dict'])
else: ## check this part else:
self.network.load_state_dict(cp) raise ValueError('Failed to load the state_dict for model file: {}'.format(model_file))
logger.debug('Loaded the pretrained PAD model') logger.debug('Loaded the pretrained PAD model')
self.network.eval() self.network.eval()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment