mirror of
https://github.com/igv/FSRCNN-TensorFlow.git
synced 2026-02-06 15:11:56 +08:00
gen.py: support for scaling factors 3 and 4
This commit is contained in:
63
gen.py
63
gen.py
@@ -1,7 +1,7 @@
|
||||
import sys
|
||||
import math
|
||||
from itertools import islice
|
||||
|
||||
scale = 2
|
||||
radius = 1
|
||||
|
||||
def get_line_number(phrase, file_name):
|
||||
@@ -23,8 +23,8 @@ def read_weights(file_name, ln, size=1):
|
||||
|
||||
return [x.strip() for x in content]
|
||||
|
||||
def format_weights(weights, n):
|
||||
return ",".join(['{:.16f}'.format(float(i)) for i in weights.strip(",").split(",")[n:n+4]])
|
||||
def format_weights(weights, n, length=4):
|
||||
return ",".join(['{:.16f}'.format(float(i)) for i in weights.strip(",").split(",")[n:n+length]])
|
||||
|
||||
def base_header(file):
|
||||
file.write('//!HOOK LUMA\n')
|
||||
@@ -70,20 +70,21 @@ def header4(file, s, m, r, n, d):
|
||||
file.write('//!SAVE EXPANDED{}\n'.format((n//4)%(d//4) + 1))
|
||||
file.write('//!COMPONENTS 4\n')
|
||||
|
||||
def header5(file, d, inp):
|
||||
def header5(file, n, d, inp):
|
||||
base_header(file)
|
||||
file.write('//!DESC sub-pixel convolution\n')
|
||||
file.write('//!DESC sub-pixel convolution {}\n'.format((n//comps) + 1))
|
||||
for i in range(d//4):
|
||||
file.write('//!BIND {}{}\n'.format(inp, i + 1))
|
||||
file.write('//!SAVE {}1\n'.format(inp))
|
||||
file.write('//!COMPONENTS 4\n')
|
||||
file.write('//!SAVE SUBCONV{}\n'.format((n//comps) + 1))
|
||||
file.write('//!COMPONENTS {}\n'.format(comps))
|
||||
|
||||
def header6(file, inp):
|
||||
def header6(file):
|
||||
base_header(file)
|
||||
file.write('//!WIDTH LUMA.w {} *\n'.format(scale))
|
||||
file.write('//!HEIGHT LUMA.h {} *\n'.format(scale))
|
||||
file.write('//!DESC aggregation\n')
|
||||
file.write('//!BIND {}1\n'.format(inp))
|
||||
for i in range(scale**2//comps):
|
||||
file.write('//!BIND SUBCONV{}\n'.format(i + 1))
|
||||
|
||||
def main():
|
||||
if len(sys.argv) == 2:
|
||||
@@ -94,6 +95,9 @@ def main():
|
||||
shrinking = False
|
||||
else:
|
||||
shrinking = True
|
||||
global scale, comps
|
||||
deconv_biases = read_weights(fname, get_line_number("deconv_b", fname))
|
||||
scale = int(math.sqrt(len(deconv_biases[0].split(","))))
|
||||
dst = fname.replace("_", "-").replace("weights", "FSRCNNX_x{}_".format(scale)).replace("txt", "glsl")
|
||||
with open(dst, 'w') as file:
|
||||
|
||||
@@ -215,33 +219,44 @@ def main():
|
||||
ln = get_line_number("deconv_b", fname)
|
||||
biases = read_weights(fname, ln)
|
||||
inp = "EXPANDED" if shrinking else "RES"
|
||||
header5(file, d, inp)
|
||||
file.write('vec4 hook()\n')
|
||||
file.write('{\n')
|
||||
file.write('vec4 res = vec4({});\n'.format(biases[0]))
|
||||
for n in range(0, scale**2, 4):
|
||||
comps = 3 if scale == 3 else 4
|
||||
for n in range(0, scale**2, comps):
|
||||
header5(file, n, d, inp)
|
||||
file.write('vec4 hook()\n')
|
||||
file.write('{\n')
|
||||
file.write('vec{0} res = vec{0}({1});\n'.format(comps, format_weights(biases[0], n, length=comps)))
|
||||
p = 0
|
||||
for l in range(0, len(weights), 4):
|
||||
if l % d == 0:
|
||||
y, x = p%(radius*2+1)-radius, p//(radius*2+1)-radius
|
||||
p += 1
|
||||
idx = (l//4)%(d//4)
|
||||
file.write('res += mat4({},{},{},{}) * {}{}_texOff(vec2({},{}));\n'.format(
|
||||
format_weights(weights[l], n), format_weights(weights[l+1], n),
|
||||
format_weights(weights[l+2], n), format_weights(weights[l+3], n),
|
||||
file.write('res += mat4x{}({},{},{},{}) * {}{}_texOff(vec2({},{}));\n'.format(
|
||||
comps, format_weights(weights[l], n, length=comps), format_weights(weights[l+1], n, length=comps),
|
||||
format_weights(weights[l+2], n, length=comps), format_weights(weights[l+3], n, length=comps),
|
||||
inp, idx + 1, x, y))
|
||||
file.write('return res;\n')
|
||||
file.write('}\n\n')
|
||||
if comps == 4:
|
||||
file.write('return res;\n')
|
||||
else:
|
||||
file.write('return vec4(res, 0);\n')
|
||||
file.write('}\n\n')
|
||||
|
||||
# Aggregation
|
||||
header6(file, inp)
|
||||
header6(file)
|
||||
file.write('vec4 hook()\n')
|
||||
file.write('{\n')
|
||||
file.write('vec2 fcoord = fract({}1_pos * {}1_size);\n'.format(inp, inp))
|
||||
file.write('vec2 base = {}1_pos + (vec2(0.5) - fcoord) * {}1_pt;\n'.format(inp, inp))
|
||||
file.write('vec2 fcoord = fract(SUBCONV1_pos * SUBCONV1_size);\n')
|
||||
file.write('vec2 base = SUBCONV1_pos + (vec2(0.5) - fcoord) * SUBCONV1_pt;\n')
|
||||
file.write('ivec2 index = ivec2(fcoord * vec2({}));\n'.format(scale))
|
||||
file.write('float res = {}1_tex(base)[index.x * {} + index.y];\n'.format(inp, scale))
|
||||
file.write('return vec4(res, 0, 0, 1);\n')
|
||||
if scale > 2:
|
||||
file.write('mat{0} res = mat{0}(SUBCONV1_tex(base).{1}'.format(scale, "rgba"[:comps]))
|
||||
for i in range(scale-1):
|
||||
file.write(',SUBCONV{}_tex(base).{}'.format(i + 2, "rgba"[:comps]))
|
||||
file.write(');\n')
|
||||
file.write('return vec4(res[index.x][index.y], 0, 0, 1);\n')
|
||||
else:
|
||||
file.write('vec4 res = SUBCONV1_tex(base);\n')
|
||||
file.write('return vec4(res[index.x * {} + index.y], 0, 0, 1);\n'.format(scale))
|
||||
file.write('}\n')
|
||||
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user