Initial commit for NFNets and AGC

PiperOrigin-RevId: 357022514
This commit is contained in:
Andy Brock
2021-02-11 19:38:33 +00:00
committed by Louise Deason
parent abc9ba0635
commit 68a1754d29
7 changed files with 646 additions and 20 deletions
+45 -1
View File
@@ -50,6 +50,50 @@ nf_regnet_params = {
'drop_rate': 0.5},
}
# NFNet F-series model configurations, keyed by variant name.
# Each entry holds the per-stage channel widths and block depths, the image
# sizes used for training and for testing, the 'RA_level' setting and the
# drop rate. Every base variant uses the same stage widths; only the depths,
# image sizes, 'RA_level' and drop rate differ between variants.
nfnet_params = {
    name: {
        'width': [256, 512, 1536, 1536],
        'depth': depth,
        'train_imsize': train_imsize,
        'test_imsize': test_imsize,
        'RA_level': ra_level,
        'drop_rate': drop_rate,
    }
    # Columns: name, depth, train_imsize, test_imsize, RA_level, drop_rate
    for name, depth, train_imsize, test_imsize, ra_level, drop_rate in (
        ('F0', [1, 2, 6, 3], 192, 256, '405', 0.2),
        ('F1', [2, 4, 12, 6], 224, 320, '410', 0.3),
        ('F2', [3, 6, 18, 9], 256, 352, '410', 0.4),
        ('F3', [4, 8, 24, 12], 320, 416, '415', 0.4),
        ('F4', [5, 10, 30, 15], 384, 512, '415', 0.5),
        ('F5', [6, 12, 36, 18], 416, 544, '415', 0.5),
        ('F6', [7, 14, 42, 21], 448, 576, '415', 0.5),
        ('F7', [8, 16, 48, 24], 480, 608, '415', 0.5),
    )
}

# Minor variants FN+, slightly wider: identical to the matching base entry
# except for 'width'. The new mapping is built in full before `update` runs,
# so only the base entries above are visited. The copy is shallow, so each
# FN+ entry shares its 'depth' list with its base entry.
nfnet_params.update({
    f'{name}+': dict(base_cfg, width=[384, 768, 2048, 2048])
    for name, base_cfg in nfnet_params.items()
})
# Nonlinearities with magic constants (gamma) baked in.
# Note that not all nonlinearities will be stable, especially if they are
@@ -122,7 +166,7 @@ def signal_metrics(x, i):
def count_conv_flops(in_ch, conv, h, w):
"""For a conv layer with in_ch inputs, count the FLOPS."""
# How many output floats are we producing?
# How many outputs are we producing? Note this is wrong for VALID padding.
output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
# At each OHW location we do computation equal to (I//G) * kh * kw
flop_per_loc = (in_ch / conv.feature_group_count)