mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Initial commit for NFNets and AGC
PiperOrigin-RevId: 357022514
This commit is contained in:
committed by
Louise Deason
parent
abc9ba0635
commit
68a1754d29
+45
-1
@@ -50,6 +50,50 @@ nf_regnet_params = {
|
||||
'drop_rate': 0.5},
|
||||
}
|
||||
|
||||
nfnet_params = {}
|
||||
|
||||
|
||||
# F-series models
|
||||
nfnet_params.update(**{
|
||||
'F0': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3],
|
||||
'train_imsize': 192, 'test_imsize': 256,
|
||||
'RA_level': '405', 'drop_rate': 0.2},
|
||||
'F1': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6],
|
||||
'train_imsize': 224, 'test_imsize': 320,
|
||||
'RA_level': '410', 'drop_rate': 0.3},
|
||||
'F2': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9],
|
||||
'train_imsize': 256, 'test_imsize': 352,
|
||||
'RA_level': '410', 'drop_rate': 0.4},
|
||||
'F3': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12],
|
||||
'train_imsize': 320, 'test_imsize': 416,
|
||||
'RA_level': '415', 'drop_rate': 0.4},
|
||||
'F4': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15],
|
||||
'train_imsize': 384, 'test_imsize': 512,
|
||||
'RA_level': '415', 'drop_rate': 0.5},
|
||||
'F5': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18],
|
||||
'train_imsize': 416, 'test_imsize': 544,
|
||||
'RA_level': '415', 'drop_rate': 0.5},
|
||||
'F6': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21],
|
||||
'train_imsize': 448, 'test_imsize': 576,
|
||||
'RA_level': '415', 'drop_rate': 0.5},
|
||||
'F7': {
|
||||
'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24],
|
||||
'train_imsize': 480, 'test_imsize': 608,
|
||||
'RA_level': '415', 'drop_rate': 0.5},
|
||||
})
|
||||
|
||||
# Minor variants FN+, slightly wider
|
||||
nfnet_params.update(**{
|
||||
**{f'{key}+': {**nfnet_params[key], 'width': [384, 768, 2048, 2048],}
|
||||
for key in nfnet_params}
|
||||
})
|
||||
|
||||
# Nonlinearities with magic constants (gamma) baked in.
|
||||
# Note that not all nonlinearities will be stable, especially if they are
|
||||
@@ -122,7 +166,7 @@ def signal_metrics(x, i):
|
||||
|
||||
def count_conv_flops(in_ch, conv, h, w):
|
||||
"""For a conv layer with in_ch inputs, count the FLOPS."""
|
||||
# How many output floats are we producing?
|
||||
# How many outputs are we producing? Note this is wrong for VALID padding.
|
||||
output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
|
||||
# At each OHW location we do computation equal to (I//G) * kh * kw
|
||||
flop_per_loc = (in_ch / conv.feature_group_count)
|
||||
|
||||
Reference in New Issue
Block a user