diff --git a/model.py b/model.py index d17598f..857a2b6 100644 --- a/model.py +++ b/model.py @@ -48,6 +48,8 @@ class FSRCNN(object): # Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper model_params = [[56, 12, 4], [32, 8, 1]] self.model_params = model_params[self.fast] + + self.deconv_radius = [3, 5, 7][self.scale - 2] self.checkpoint_dir = config.checkpoint_dir self.output_dir = config.output_dir @@ -65,11 +67,12 @@ class FSRCNN(object): d, s, m = self.model_params expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4) + deconv_size = self.deconv_radius * 2 + 1 self.weights = { 'w1': tf.Variable(tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32), name='w1'), 'w2': tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32), name='w2'), expand_weight: tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32), name=expand_weight), - deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight) + deconv_weight: tf.Variable(tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight) } expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)