
Gan

Closed · Guillaume HEUSCH requested to merge gan into master
1 file changed  + 6  − 6
@@ -65,7 +65,7 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
         self.pickle_architecture = pickle.dumps(self.sequence_net)
         self.deployment_shape = shape
 
-    def compute_graph(self, input_data, feature_layer=None, training=True):
+    def compute_graph(self, input_data, feature_layer=None, training=True, scope=None):
         """Given the current network, return the Tensorflow graph
 
         **Parameter**
@@ -77,18 +77,18 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
           training: If `True` will generating the graph for training
 
         """
         input_offset = input_data
         for k in self.sequence_net.keys():
             current_layer = self.sequence_net[k]
 
             if training or not isinstance(current_layer, Dropout):
-                current_layer.create_variables(input_offset)
+                current_layer.create_variables(input_offset, scope=scope)
                 input_offset = current_layer.get_graph(training_phase=training)
 
                 if feature_layer is not None and k == feature_layer:
                     return input_offset
 
         return input_offset
 
     def compute_inference_graph(self, feature_layer=None):
@@ -148,9 +148,9 @@ class SequenceNetwork(six.with_metaclass(abc.ABCMeta, object)):
                 variables[self.sequence_net[k].b.name] = self.sequence_net[k].b
 
                 # Dumping batch norm variables
-                if self.sequence_net[k].batch_norm:
-                    variables[self.sequence_net[k].beta.name] = self.sequence_net[k].beta
-                    variables[self.sequence_net[k].gamma.name] = self.sequence_net[k].gamma
+                #if self.sequence_net[k].batch_norm:
+                    #variables[self.sequence_net[k].beta.name] = self.sequence_net[k].beta
+                    #variables[self.sequence_net[k].gamma.name] = self.sequence_net[k].gamma
                     #variables[self.sequence_net[k].mean.name] = self.sequence_net[k].mean
                     #variables[self.sequence_net[k].var.name] = self.sequence_net[k].var
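
Besides commenting out the batch-norm variable dumping, the main change adds an optional scope argument to compute_graph, which is forwarded to every layer's create_variables call so that two copies of a network can keep their variables under separate names, as a GAN needs for its generator and discriminator. A minimal usage sketch under that assumption follows; the architecture class names, placeholder shapes, and scope strings are illustrative assumptions and not part of this merge request.

    import tensorflow as tf

    # Hypothetical SequenceNetwork subclasses standing in for the GAN's generator
    # and discriminator; the class names and constructors are assumptions made
    # only for illustration.
    generator_net = GeneratorArchitecture()
    discriminator_net = DiscriminatorArchitecture()

    # Example inputs (TensorFlow 1.x style placeholders); the shapes are arbitrary.
    noise = tf.placeholder(tf.float32, shape=(None, 100))
    images = tf.placeholder(tf.float32, shape=(None, 28, 28, 1))

    # The new keyword argument is passed down to each layer's create_variables(),
    # keeping the two networks' variables under distinct scopes.
    fake_images = generator_net.compute_graph(noise, training=True, scope="generator")
    decision = discriminator_net.compute_graph(images, training=True, scope="discriminator")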