mirror of https://github.com/igv/FSRCNN-TensorFlow.git (synced 2025-12-14 07:24:05 +08:00)
Make it possible to disable shrinking/expanding layers
model.py (72 lines changed)
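
The change in brief: weight and bias creation moves out of the constructor path (build_model(), renamed init_model()) and into model(), where each layer's variables are created with tf.get_variable only if the layer is actually built. The shrinking and expanding layers are skipped when model_params[1] (the shrink width s) is 0, and the mapping layers then run at the full feature depth d. A minimal sketch of the resulting layer sequence; layer_plan is an illustrative helper, not part of the repo:

# Illustrative helper (not in the repo): lists the layers the refactored
# model() builds for a given model_params = (d, s, m).
def layer_plan(d, s, m):
    layers = ['w1: feature extraction 5x5, 1 -> %d' % d]
    if s > 0:                       # shrinking is built only when s > 0
        layers.append('w2: shrink 1x1, %d -> %d' % (d, s))
    width = s if s > 0 else d       # mirrors `if s == 0: s = d` in model()
    for i in range(3, m + 3):       # the m mapping layers
        layers.append('w%d: map 3x3, %d -> %d' % (i, width, width))
    if s > 0:                       # expanding is paired with shrinking
        layers.append('w%d: expand 1x1, %d -> %d' % (m + 3, width, d))
    layers.append('w%d: deconv, %d -> 1' % (m + 4, d))
    return layers

print(layer_plan(56, 12, 4))  # all layers present
print(layer_plan(32, 0, 5))   # shrink/expand disabled; mapping stays at 32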
@@ -54,43 +54,16 @@ class FSRCNN(object):
     self.checkpoint_dir = config.checkpoint_dir
     self.output_dir = config.output_dir
     self.data_dir = config.data_dir
-    self.build_model()
+    self.init_model()
 
-  def build_model(self):
+  def init_model(self):
     self.images = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.c_dim], name='images')
     self.labels = tf.placeholder(tf.float32, [None, self.label_size, self.label_size, self.c_dim], name='labels')
     # Batch size differs in training vs testing
     self.batch = tf.placeholder(tf.int32, shape=[], name='batch')
 
     # FSRCNN-s (fast) has smaller filters and fewer layers but can achieve faster performance
     d, s, m = self.model_params
 
-    expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4)
-    deconv_size = self.deconv_radius * 2 + 1
-    self.weights = {
-      'w1': tf.Variable(tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32), name='w1'),
-      'w2': tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32), name='w2'),
-      expand_weight: tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32), name=expand_weight),
-      deconv_weight: tf.Variable(tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
-    }
-
-    expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)
-    self.biases = {
-      'b1': tf.Variable(tf.zeros([d]), name='b1'),
-      'b2': tf.Variable(tf.zeros([s]), name='b2'),
-      expand_bias: tf.Variable(tf.zeros([d]), name=expand_bias),
-      deconv_bias: tf.Variable(tf.zeros([1]), name=deconv_bias)
-    }
-
-    self.alphas = {}
-
-    # Create the m mapping layers' weights/biases
-    for i in range(3, m + 3):
-      weight_name, bias_name = 'w{}'.format(i), 'b{}'.format(i)
-      self.weights[weight_name] = tf.Variable(tf.random_normal([3, 3, s, s], stddev=0.1179, dtype=tf.float32), name=weight_name)
-      self.biases[bias_name] = tf.Variable(tf.zeros([s]), name=bias_name)
+    self.weights, self.biases, self.alphas = {}, {}, {}
 
     self.pred = self.model()
 
     # Loss function (structural dissimilarity)
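
Note that the hunk above only empties the weight/bias/alpha dictionaries; the variables themselves are now created lazily inside model() (next hunk) via tf.get_variable, so parameters exist only for layers that are actually built. A minimal sketch of that pattern, assuming TF 1.x (the function name and the standalone dicts are illustrative, not from the repo):

import tensorflow as tf

weights, biases = {}, {}

def build_shrink_if_enabled(d, s):
    # Create the shrink layer's parameters only when the layer is
    # enabled, mirroring the `if self.model_params[1] > 0` guard below.
    if s > 0:
        weights['w2'] = tf.get_variable(
            'w2', initializer=tf.random_normal([1, 1, d, s], stddev=0.3536))
        biases['b2'] = tf.get_variable('b2', initializer=tf.zeros([s]))

build_shrink_if_enabled(56, 0)   # nothing is created: shrinking disabled
build_shrink_if_enabled(56, 12)  # 'w2'/'b2' are now in the graph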
@@ -192,29 +165,46 @@ class FSRCNN(object):
     array_image_save(result * 255, image_path)
 
   def model(self):
     d, s, m = self.model_params
 
     # Feature Extraction
-    conv_feature = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)
+    self.weights['w1'] = tf.get_variable('w1', initializer=tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32))
+    self.biases['b1'] = tf.get_variable('b1', initializer=tf.zeros([d]))
+    conv = self.prelu(tf.nn.conv2d(self.images, self.weights['w1'], strides=[1,1,1,1], padding='VALID') + self.biases['b1'], 1)
 
     # Shrinking
-    conv_shrink = self.prelu(tf.nn.conv2d(conv_feature, self.weights['w2'], strides=[1,1,1,1], padding='SAME') + self.biases['b2'], 2)
+    if self.model_params[1] > 0:
+      self.weights['w2'] = tf.get_variable('w2', initializer=tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32))
+      self.biases['b2'] = tf.get_variable('b2', initializer=tf.zeros([s]))
+      conv = self.prelu(tf.nn.conv2d(conv, self.weights['w2'], strides=[1,1,1,1], padding='SAME') + self.biases['b2'], 2)
 
     # Mapping (# mapping layers = m)
-    prev_layer, m = conv_shrink, self.model_params[2]
+    if s == 0:
+      s = d
     for i in range(3, m + 3):
-      weights, biases = self.weights['w{}'.format(i)], self.biases['b{}'.format(i)]
-      prev_layer = self.prelu(tf.nn.conv2d(prev_layer, weights, strides=[1,1,1,1], padding='SAME') + biases, i)
+      weights = tf.get_variable('w{}'.format(i), initializer=tf.random_normal([3, 3, s, s], stddev=0.1179, dtype=tf.float32))
+      biases = tf.get_variable('b{}'.format(i), initializer=tf.zeros([s]))
+      self.weights['w{}'.format(i)], self.biases['b{}'.format(i)] = weights, biases
+      conv = self.prelu(tf.nn.conv2d(conv, weights, strides=[1,1,1,1], padding='SAME') + biases, i)
 
     # Expanding
-    expand_weights, expand_biases = self.weights['w{}'.format(m + 3)], self.biases['b{}'.format(m + 3)]
-    conv_expand = self.prelu(tf.nn.conv2d(prev_layer, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, m + 3)
+    if self.model_params[1] > 0:
+      expand_weights = tf.get_variable('w{}'.format(m + 3), initializer=tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32))
+      expand_biases = tf.get_variable('b{}'.format(m + 3), initializer=tf.zeros([d]))
+      self.weights['w{}'.format(m + 3)], self.biases['b{}'.format(m + 3)] = expand_weights, expand_biases
+      conv = self.prelu(tf.nn.conv2d(conv, expand_weights, strides=[1,1,1,1], padding='SAME') + expand_biases, m + 3)
 
     # Deconvolution
+    deconv_size = self.deconv_radius * 2 + 1
+    deconv_weights = tf.get_variable('w{}'.format(m + 4), initializer=tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32))
+    deconv_biases = tf.get_variable('b{}'.format(m + 4), initializer=tf.zeros([1]))
+    self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)] = deconv_weights, deconv_biases
     deconv_output = [self.batch, self.label_size, self.label_size, self.c_dim]
     deconv_stride = [1, self.scale, self.scale, 1]
-    deconv_weights, deconv_biases = self.weights['w{}'.format(m + 4)], self.biases['b{}'.format(m + 4)]
-    conv_deconv = tf.nn.conv2d_transpose(conv_expand, deconv_weights, output_shape=deconv_output, strides=deconv_stride, padding='SAME') + deconv_biases
+    deconv = tf.nn.conv2d_transpose(conv, deconv_weights, output_shape=deconv_output, strides=deconv_stride, padding='SAME') + deconv_biases
 
-    return conv_deconv
+    return deconv
 
   def prelu(self, _x, i):
     """
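
A note on the deconvolution step above: tf.nn.conv2d_transpose takes its filter as [height, width, output_channels, input_channels], which is why the weight shape [deconv_size, deconv_size, 1, d] maps d feature channels down to a single output channel, and with padding='SAME' the spatial size grows by exactly the stride. A standalone sanity check under hypothetical sizes (scale, d, deconv_size, and label_size below are made up for the example; TF 1.x assumed):

import numpy as np
import tensorflow as tf

scale, d, deconv_size, label_size = 2, 8, 5, 20    # hypothetical sizes

x = tf.placeholder(tf.float32, [None, 10, 10, d])  # mapped features
batch = tf.placeholder(tf.int32, shape=[])         # like self.batch
w = tf.get_variable('w_deconv', initializer=tf.random_normal(
    [deconv_size, deconv_size, 1, d], stddev=0.0001))  # [h, w, out_ch, in_ch]
y = tf.nn.conv2d_transpose(x, w, output_shape=[batch, label_size, label_size, 1],
                           strides=[1, scale, scale, 1], padding='SAME')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, {x: np.zeros((4, 10, 10, d), np.float32), batch: 4})
    print(out.shape)  # (4, 20, 20, 1): spatial size = 10 * scale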