Commit 66948f20 authored by Tiago Pereira's avatar Tiago Pereira

Added keyword when loading from file

parent e90a4429
Pipeline #11221 passed with stages
in 12 minutes and 14 seconds
...@@ -159,10 +159,10 @@ class SiameseTrainer(Trainer): ...@@ -159,10 +159,10 @@ class SiameseTrainer(Trainer):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file): def create_network_from_file(self, model_from_file, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False) #saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta") self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file) self.saver.restore(self.session, model_from_file)
# Loading the graph from the graph pointers # Loading the graph from the graph pointers
......
...@@ -164,7 +164,7 @@ class Trainer(object): ...@@ -164,7 +164,7 @@ class Trainer(object):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, file_name): def create_network_from_file(self, file_name, clear_devices=True):
""" """
Bootstrap a graph from a checkpoint Bootstrap a graph from a checkpoint
...@@ -172,7 +172,7 @@ class Trainer(object): ...@@ -172,7 +172,7 @@ class Trainer(object):
file_name: Name of of the checkpoing file_name: Name of of the checkpoing
""" """
self.saver = tf.train.import_meta_graph(file_name + ".meta") self.saver = tf.train.import_meta_graph(file_name + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, file_name) self.saver.restore(self.session, file_name)
# Loading training graph # Loading training graph
......
...@@ -178,10 +178,10 @@ class TripletTrainer(Trainer): ...@@ -178,10 +178,10 @@ class TripletTrainer(Trainer):
# Creating the variables # Creating the variables
tf.global_variables_initializer().run(session=self.session) tf.global_variables_initializer().run(session=self.session)
def create_network_from_file(self, model_from_file): def create_network_from_file(self, model_from_file, clear_devices=True):
#saver = self.architecture.load(self.model_from_file, clear_devices=False) #saver = self.architecture.load(self.model_from_file, clear_devices=False)
self.saver = tf.train.import_meta_graph(model_from_file + ".meta") self.saver = tf.train.import_meta_graph(model_from_file + ".meta", clear_devices=clear_devices)
self.saver.restore(self.session, model_from_file) self.saver.restore(self.session, model_from_file)
# Loading the graph from the graph pointers # Loading the graph from the graph pointers
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment