mirror of
https://github.com/igv/FSRCNN-TensorFlow.git
synced 2025-12-16 01:24:28 +08:00
Adjust deconvolution filter size for different scaling factors
This commit is contained in:
5
model.py
5
model.py
@@ -48,6 +48,8 @@ class FSRCNN(object):
|
||||
# Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
|
||||
model_params = [[56, 12, 4], [32, 8, 1]]
|
||||
self.model_params = model_params[self.fast]
|
||||
|
||||
self.deconv_radius = [3, 5, 7][self.scale - 2]
|
||||
|
||||
self.checkpoint_dir = config.checkpoint_dir
|
||||
self.output_dir = config.output_dir
|
||||
@@ -65,11 +67,12 @@ class FSRCNN(object):
|
||||
d, s, m = self.model_params
|
||||
|
||||
expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4)
|
||||
deconv_size = self.deconv_radius * 2 + 1
|
||||
self.weights = {
|
||||
'w1': tf.Variable(tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32), name='w1'),
|
||||
'w2': tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32), name='w2'),
|
||||
expand_weight: tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32), name=expand_weight),
|
||||
deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
|
||||
deconv_weight: tf.Variable(tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
|
||||
}
|
||||
|
||||
expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)
|
||||
|
||||
Reference in New Issue
Block a user