diff --git a/FSRCNN.py b/FSRCNN.py index f12435f..93acbac 100644 --- a/FSRCNN.py +++ b/FSRCNN.py @@ -72,7 +72,8 @@ class Model(object): deconv_biases = tf.get_variable('deconv_b', initializer=tf.zeros([self.scale**2])) deconv = tf.nn.conv2d(conv, deconv_weights, strides=[1,1,1,1], padding='SAME', data_format='NHWC') deconv = tf.nn.bias_add(deconv, deconv_biases, data_format='NHWC') - deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC') + if self.scale > 1: + deconv = tf.depth_to_space(deconv, self.scale, name='pixel_shuffle', data_format='NHWC') return deconv diff --git a/gen.py b/gen.py index 9b1b5cd..f6c538b 100644 --- a/gen.py +++ b/gen.py @@ -28,7 +28,8 @@ def format_weights(weights, n, length=4): def base_header(file): file.write('//!HOOK LUMA\n') - file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1)) + if scale > 1: + file.write('//!WHEN OUTPUT.w LUMA.w / {0}.400 > OUTPUT.h LUMA.h / {0}.400 > *\n'.format(scale - 1)) def header1(file, n, d): base_header(file) @@ -75,8 +76,9 @@ def header5(file, n, d, inp): file.write('//!DESC sub-pixel convolution {}\n'.format((n//comps) + 1)) for i in range(d//4): file.write('//!BIND {}{}\n'.format(inp, i + 1)) - file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1)) - file.write('//!COMPONENTS {}\n'.format(comps)) + if scale > 1: + file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1)) + file.write('//!COMPONENTS {}\n'.format(comps)) def header6(file): base_header(file) @@ -219,45 +221,47 @@ def main(): ln = get_line_number("deconv_b", fname) biases = read_weights(fname, ln) inp = "EXPANDED" if shrinking else "RES" - comps = 3 if scale == 3 else 4 + comps = scale if scale % 2 == 1 else 4 for n in range(0, scale**2, comps): header5(file, n, d, inp) file.write('vec4 hook()\n') file.write('{\n') - file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps))) + if scale == 1: + file.write('float res = {};\n'.format(format_weights(biases[0], n, length=comps))) + else: + file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps))) p = 0 for l in range(0, len(weights), 4): if l % d == 0: y, x = p%(radius*2+1)-radius, p//(radius*2+1)-radius p += 1 idx = (l//4)%(d//4) - file.write('res += mat4x{}({},{},{},{}) * {}{}_texOff(vec2({},{}));\n'.format( - comps, format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps), + file.write('res += {}{}({},{},{},{}){} {}{}_texOff(vec2({},{})){};\n'.format( + "mat4x" if scale > 1 else "dot(", comps if scale > 1 else "vec4", + format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps), format_weights(weights[l+2], n, length=comps), format_weights(weights[l+3], n, length=comps), - inp, idx + 1, x, y)) - if comps == 4: - file.write('return res;\n') - else: - file.write('return vec4(res, 0);\n') + " *" if scale > 1 else ",", inp, idx + 1, x, y, "" if scale > 1 else ")")) + file.write('return vec4(res{});\n'.format(", 0" * (4 - comps))) file.write('}\n\n') - # Aggregation - header6(file) - file.write('vec4 hook()\n') - file.write('{\n') - file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n') - file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n') - file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale)) - if scale > 2: - file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps])) - for i in range(scale-1): - file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps])) - file.write(');\n') - file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n') - else: - file.write('vec4 res = SUBCONV1_tex(base);\n') - file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale)) - file.write('}\n') + if scale > 1: + # Aggregation + header6(file) + file.write('vec4 hook()\n') + file.write('{\n') + file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n') + file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n') + file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale)) + if scale > 2: + file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps])) + for i in range(scale-1): + file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps])) + file.write(');\n') + file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n') + else: + file.write('vec4 res = SUBCONV1_tex(base);\n') + file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale)) + file.write('}\n') else: print("Missing argument: You must specify a file name") diff --git a/model.py b/model.py index 5db5488..7cfb375 100644 --- a/model.py +++ b/model.py @@ -37,8 +37,8 @@ class Model(object): self.padding = 4 # Different image/label sub-sizes for different scaling factors x2, x3, x4 - scale_factors = [[20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]] - self.image_size, self.label_size = scale_factors[self.scale - 2] + scale_factors = [[40 + self.padding, 40], [20 + self.padding, 40], [14 + self.padding, 42], [12 + self.padding, 48]] + self.image_size, self.label_size = scale_factors[self.scale - 1] self.stride = self.image_size - self.padding