Skip to content
Snippets Groups Projects
Commit 66948f20 authored by Tiago Pereira's avatar Tiago Pereira
Browse files

Added keyword when loading from file

parent e90a4429
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -159,10 +159,10 @@ class SiameseTrainer(Trainer): ...@@ -159,10 +159,10 @@ class SiameseTrainer(Trainer):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file): def create_network_from_file(self, model_from_file, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False) #saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta") self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file) self.saver.restore(self.session, model_from_file)
# Loading the graph from the graph pointers # Loading the graph from the graph pointers
......
...@@ -164,7 +164,7 @@ class Trainer(object): ...@@ -164,7 +164,7 @@ class Trainer(object):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, file_name): def create_network_from_file(self, file_name, clear_devices=True):
""" """
Bootstrap a graph from a checkpoint Bootstrap a graph from a checkpoint
...@@ -172,7 +172,7 @@ class Trainer(object): ...@@ -172,7 +172,7 @@ class Trainer(object):
file_name: Name of of the checkpoing file_name: Name of of the checkpoing
""" """
self.saver = tf.train.import_meta_graph(file_name + ".meta") self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, file_name) self.saver.restore(self.session, file_name)
# Loading training graph # Loading training graph
......
...@@ -178,10 +178,10 @@ class TripletTrainer(Trainer): ...@@ -178,10 +178,10 @@ class TripletTrainer(Trainer):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file): def create_network_from_file(self, model_from_file, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False) #saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta") self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file) self.saver.restore(self.session, model_from_file)
# Loading the graph from the graph pointers # Loading the graph from the graph pointers
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment