Mirror of https://github.com/igv/FSRCNN-TensorFlow.git, synced 2025-12-15 00:58:23 +08:00
model.py: also dump alpha variables
I've also fixed a bug where the last alpha was hard-coded to alpha7 instead of being dynamic relative to m (i.e. m + 3). With the literal index, a high enough m would have had a mapping layer and the expanding layer both claiming alpha7.
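A note on the numbering, read off the diff below: w1/b1 is feature extraction, w2/b2 is shrinking, w3 through w{m+2} are the m mapping layers, and w{m+3} is expanding, so the expanding layer's PReLU index must track m. A minimal sketch of that arithmetic (expand_prelu_index is a hypothetical helper, not code from this repo):

# Sketch only: the expanding layer is layer m + 3, so its PReLU index
# must be m + 3; the old literal 7 was correct only for m == 4.
def expand_prelu_index(m):
    return m + 3

for m in (2, 4, 6):
    print(m, expand_prelu_index(m))  # 2 -> 5, 4 -> 7, 6 -> 9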
model.py (7 lines changed)
@@ -77,6 +77,8 @@ class FSRCNN(object):
       'deconv_bias': tf.Variable(tf.zeros([1]), name='deconv_bias')
     }
 
+    self.alphas = {}
+
     # Create the m mapping layers weights/biases
     for i in range(3, m + 3):
       weight_name, bias_name = 'w{}'.format(i), 'b{}'.format(i)
@@ -102,7 +104,7 @@ class FSRCNN(object):
         print(" [!] Load failed...")
 
     if self.params:
-      save_params(self.sess, self.weights, self.biases)
+      save_params(self.sess, self.weights, self.biases, self.alphas)
     elif self.train:
       self.run_train()
     else:
@@ -197,7 +199,7 @@ class FSRCNN(object):
 
     # Expanding
     expand_weights, expand_biases = self.weights['w{}'.format(m + 3)], self.biases['b{}'.format(m + 3)]
-    conv_expand = self.prelu(tf.nn.conv2d(prev_layer, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, 7)
+    conv_expand = self.prelu(tf.nn.conv2d(prev_layer, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, m + 3)
 
     # Deconvolution
     deconv_output = [self.batch, self.label_size, self.label_size, self.c_dim]
@@ -212,6 +214,7 @@ class FSRCNN(object):
     PreLU tensorflow implementation
     """
     alphas = tf.get_variable('alpha{}'.format(i), _x.get_shape()[-1], initializer=tf.constant_initializer(0.0), dtype=tf.float32)
+    self.alphas['alpha{}'.format(i)] = alphas
     pos = tf.nn.relu(_x)
     neg = alphas * (_x - abs(_x)) * 0.5
 
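Aside on the prelu body above: since (x - abs(x)) / 2 equals min(x, 0), pos + neg computes max(x, 0) + alpha * min(x, 0), which is exactly PReLU. A quick NumPy check of that identity (illustrative sketch, not repo code):

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5])
alpha = 0.25
pos = np.maximum(x, 0.0)             # same role as tf.nn.relu(_x)
neg = alpha * (x - np.abs(x)) * 0.5  # same formula as the diff
assert np.allclose(neg, alpha * np.minimum(x, 0.0))
print(pos + neg)                     # [-0.5 -0.125 0. 1.5]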
utils.py (12 lines changed)
@@ -285,7 +285,7 @@ def test_input_setup(config):
   return nx, ny
 
 # You can ignore, I just wanted to see how much space all the parameters would take up
-def save_params(sess, weights, biases):
+def save_params(sess, weights, biases, alphas):
   param_dir = "params/"
 
   if not os.path.exists(param_dir):
@@ -319,6 +319,16 @@ def save_params(sess, weights, biases):
 
   bias_file.close()
 
+  alpha_file = open(param_dir + "alpha.txt", 'w')
+  for layer in alphas:
+    alpha_file.write("Layer {}\n".format(layer))
+    layer_alphas = sess.run(alphas[layer])
+    for alpha in layer_alphas:
+      alpha_file.write("{}, ".format(alpha))
+    alpha_file.write("\n\n")
+
+  alpha_file.close()
+
 def merge(images, size):
   """
   Merges sub-images back into original image size
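For reference, the added loop writes one "Layer <name>" header per entry, followed by a comma-separated line of that layer's alpha values and a blank line. A hedged sketch of a reader for that layout (load_alphas is a hypothetical helper, not part of the repo):

# Assumes the alpha.txt format written by save_params above.
def load_alphas(path="params/alpha.txt"):
    alphas = {}
    with open(path) as f:
        blocks = f.read().strip().split("\n\n")
    for block in blocks:
        header, values = block.split("\n", 1)
        name = header.replace("Layer ", "").strip()
        alphas[name] = [float(v) for v in values.split(",") if v.strip()]
    return alphas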