# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Architecture definitions for different models."""

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


# Model settings for NF-RegNets
nf_regnet_params = {
    'B0': {'width': [48, 104, 208, 440], 'depth': [1, 3, 6, 6],
           'train_imsize': 192, 'test_imsize': 224,
           'drop_rate': 0.2},
    'B1': {'width': [48, 104, 208, 440], 'depth': [2, 4, 7, 7],
           'train_imsize': 224, 'test_imsize': 256,
           'drop_rate': 0.2},
    'B2': {'width': [56, 112, 232, 488], 'depth': [2, 4, 8, 8],
           'train_imsize': 240, 'test_imsize': 272,
           'drop_rate': 0.3},
    'B3': {'width': [56, 128, 248, 528], 'depth': [2, 5, 9, 9],
           'train_imsize': 288, 'test_imsize': 320,
           'drop_rate': 0.3},
    'B4': {'width': [64, 144, 288, 616], 'depth': [2, 6, 11, 11],
           'train_imsize': 320, 'test_imsize': 384,
           'drop_rate': 0.4},
    'B5': {'width': [80, 168, 336, 704], 'depth': [3, 7, 14, 14],
           'train_imsize': 384, 'test_imsize': 456,
           'drop_rate': 0.4},
    'B6': {'width': [88, 184, 376, 792], 'depth': [3, 8, 16, 16],
           'train_imsize': 448, 'test_imsize': 528,
           'drop_rate': 0.5},
    'B7': {'width': [96, 208, 416, 880], 'depth': [4, 10, 19, 19],
           'train_imsize': 512, 'test_imsize': 600,
           'drop_rate': 0.5},
    'B8': {'width': [104, 232, 456, 968], 'depth': [4, 11, 22, 22],
           'train_imsize': 600, 'test_imsize': 672,
           'drop_rate': 0.5},
}


nfnet_params = {}


# F-series models
nfnet_params.update(**{
    'F0': {
        'width': [256, 512, 1536, 1536], 'depth': [1, 2, 6, 3],
        'train_imsize': 192, 'test_imsize': 256,
        'RA_level': '405', 'drop_rate': 0.2},
    'F1': {
        'width': [256, 512, 1536, 1536], 'depth': [2, 4, 12, 6],
        'train_imsize': 224, 'test_imsize': 320,
        'RA_level': '410', 'drop_rate': 0.3},
    'F2': {
        'width': [256, 512, 1536, 1536], 'depth': [3, 6, 18, 9],
        'train_imsize': 256, 'test_imsize': 352,
        'RA_level': '410', 'drop_rate': 0.4},
    'F3': {
        'width': [256, 512, 1536, 1536], 'depth': [4, 8, 24, 12],
        'train_imsize': 320, 'test_imsize': 416,
        'RA_level': '415', 'drop_rate': 0.4},
    'F4': {
        'width': [256, 512, 1536, 1536], 'depth': [5, 10, 30, 15],
        'train_imsize': 384, 'test_imsize': 512,
        'RA_level': '415', 'drop_rate': 0.5},
    'F5': {
        'width': [256, 512, 1536, 1536], 'depth': [6, 12, 36, 18],
        'train_imsize': 416, 'test_imsize': 544,
        'RA_level': '415', 'drop_rate': 0.5},
    'F6': {
        'width': [256, 512, 1536, 1536], 'depth': [7, 14, 42, 21],
        'train_imsize': 448, 'test_imsize': 576,
        'RA_level': '415', 'drop_rate': 0.5},
    'F7': {
        'width': [256, 512, 1536, 1536], 'depth': [8, 16, 48, 24],
        'train_imsize': 480, 'test_imsize': 608,
        'RA_level': '415', 'drop_rate': 0.5},
})


# Minor variants FN+, slightly wider
nfnet_params.update(**{
    **{f'{key}+': {**nfnet_params[key], 'width': [384, 768, 2048, 2048]}
       for key in nfnet_params}
})
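
# Example lookup (an illustrative addition, not part of the original file):
# each entry maps a variant name to its per-stage widths/depths, train/test
# image sizes, and dropout rate, so configs can be fetched by name.
# >>> nfnet_params['F0']['depth']
# [1, 2, 6, 3]
# >>> nfnet_params['F0+']['width']  # '+' variants share all other settings
# [384, 768, 2048, 2048]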


# Nonlinearities with magic constants (gamma) baked in.
# Note that not all nonlinearities will be stable, especially if they are
# not perfectly monotonic. Good choices include relu, silu, and gelu.
nonlinearities = {
    'identity': lambda x: x,
    'celu': lambda x: jax.nn.celu(x) * 1.270926833152771,
    'elu': lambda x: jax.nn.elu(x) * 1.2716004848480225,
    'gelu': lambda x: jax.nn.gelu(x) * 1.7015043497085571,
    'glu': lambda x: jax.nn.glu(x) * 1.8484294414520264,
    'leaky_relu': lambda x: jax.nn.leaky_relu(x) * 1.70590341091156,
    'log_sigmoid': lambda x: jax.nn.log_sigmoid(x) * 1.9193484783172607,
    'log_softmax': lambda x: jax.nn.log_softmax(x) * 1.0002083778381348,
    'relu': lambda x: jax.nn.relu(x) * 1.7139588594436646,
    'relu6': lambda x: jax.nn.relu6(x) * 1.7131484746932983,
    'selu': lambda x: jax.nn.selu(x) * 1.0008515119552612,
    'sigmoid': lambda x: jax.nn.sigmoid(x) * 4.803835391998291,
    'silu': lambda x: jax.nn.silu(x) * 1.7881293296813965,
    'soft_sign': lambda x: jax.nn.soft_sign(x) * 2.338853120803833,
    'softplus': lambda x: jax.nn.softplus(x) * 1.9203323125839233,
    'tanh': lambda x: jnp.tanh(x) * 1.5939117670059204,
}
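

def _check_unit_variance(name):
  """Illustrative helper (an addition; not in the original repo).

  The gamma constants above are chosen so that, for unit-Gaussian inputs,
  the scaled nonlinearity has approximately unit output variance.
  """
  x = jax.random.normal(jax.random.PRNGKey(0), (1024, 1024))
  return jnp.var(nonlinearities[name](x))  # Expect a value near 1.0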


class WSConv2D(hk.Conv2D):
  """2D Convolution with Scaled Weight Standardization and affine gain+bias."""

  @hk.transparent
  def standardize_weight(self, weight, eps=1e-4):
    """Apply scaled WS with affine gain."""
    mean = jnp.mean(weight, axis=(0, 1, 2), keepdims=True)
    var = jnp.var(weight, axis=(0, 1, 2), keepdims=True)
    fan_in = np.prod(weight.shape[:-1])
    # Get gain
    gain = hk.get_parameter('gain', shape=(weight.shape[-1],),
                            dtype=weight.dtype, init=jnp.ones)
    # Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
    scale = jax.lax.rsqrt(jnp.maximum(var * fan_in, eps)) * gain
    shift = mean * scale
    return weight * scale - shift

  def __call__(self, inputs: jnp.ndarray, eps: float = 1e-4) -> jnp.ndarray:
    w_shape = self.kernel_shape + (
        inputs.shape[self.channel_index] // self.feature_group_count,
        self.output_channels)
    # Use fan-in scaled init, but WS is largely insensitive to this choice.
    w_init = hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')
    w = hk.get_parameter('w', w_shape, inputs.dtype, init=w_init)
    weight = self.standardize_weight(w, eps)
    out = jax.lax.conv_general_dilated(
        inputs, weight, window_strides=self.stride, padding=self.padding,
        lhs_dilation=self.lhs_dilation, rhs_dilation=self.kernel_dilation,
        dimension_numbers=self.dimension_numbers,
        feature_group_count=self.feature_group_count)
    # Always add bias
    bias_shape = (self.output_channels,)
    bias = hk.get_parameter('bias', bias_shape, inputs.dtype, init=jnp.zeros)
    return out + bias
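

# Minimal usage sketch (an illustrative addition; shapes and hyperparameters
# are arbitrary): like any Haiku module, WSConv2D must be constructed and
# applied inside an hk.transform.
def _ws_conv_demo(images):
  # Hypothetical demo: a single 3x3 WS conv applied to NHWC images.
  return WSConv2D(output_channels=16, kernel_shape=3)(images)

# forward = hk.transform(_ws_conv_demo)
# params = forward.init(jax.random.PRNGKey(0), jnp.zeros([2, 32, 32, 3]))
# out = forward.apply(params, None, jnp.zeros([2, 32, 32, 3]))  # (2, 32, 32, 16)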


def signal_metrics(x, i):
  """Things to measure about an NHWC tensor activation."""
  metrics = {}
  # Average channel-wise mean-squared
  metrics[f'avg_sq_mean_{i}'] = jnp.mean(jnp.mean(x, axis=[0, 1, 2])**2)
  # Average channel variance
  metrics[f'avg_var_{i}'] = jnp.mean(jnp.var(x, axis=[0, 1, 2]))
  return metrics
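
# Example (illustrative): logging signal statistics for the activations of
# block i; a unit-Gaussian NHWC tensor gives avg_sq_mean near 0 and avg_var
# near 1.
# x = jax.random.normal(jax.random.PRNGKey(0), (8, 32, 32, 64))
# signal_metrics(x, 0)  # {'avg_sq_mean_0': ~0.0, 'avg_var_0': ~1.0}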


def count_conv_flops(in_ch, conv, h, w):
  """For a conv layer with in_ch inputs, count the FLOPs."""
  # How many outputs are we producing? Note this is wrong for VALID padding.
  output_shape = conv.output_channels * (h * w) / np.prod(conv.stride)
  # At each OHW location we do computation equal to (I//G) * kh * kw
  flop_per_loc = (in_ch / conv.feature_group_count)
  flop_per_loc *= np.prod(conv.kernel_shape)
  return output_shape * flop_per_loc
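
# Worked example (illustrative numbers, not from the original file): a 3x3
# conv with stride 1 mapping 64 -> 128 channels on a 32x32 feature map
# produces 128 * 32 * 32 output values, each costing 64 * 3 * 3
# multiply-adds: 128 * 1024 * 576 = 75,497,472 FLOPs.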


class SqueezeExcite(hk.Module):
  """Simple Squeeze+Excite module."""

  def __init__(self, in_ch, out_ch, se_ratio=0.5,
               hidden_ch=None, activation=jax.nn.relu,
               name=None):
    super().__init__(name=name)
    self.in_ch, self.out_ch = in_ch, out_ch
    if se_ratio is None:
      if hidden_ch is None:
        raise ValueError('Must provide one of se_ratio or hidden_ch')
      self.hidden_ch = hidden_ch
    else:
      self.hidden_ch = max(1, int(self.in_ch * se_ratio))
    self.activation = activation
    self.fc0 = hk.Linear(self.hidden_ch, with_bias=True)
    self.fc1 = hk.Linear(self.out_ch, with_bias=True)

  def __call__(self, x):
    h = jnp.mean(x, axis=[1, 2])  # Mean pool over HW extent
    h = self.fc1(self.activation(self.fc0(h)))
    h = jax.nn.sigmoid(h)[:, None, None]  # Broadcast along H, W
    return h
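
# Usage note (added for clarity; this follows standard SE practice): the
# module returns the sigmoid gate rather than the gated activations, so the
# caller multiplies it back in, e.g. inside a transformed function:
#   se = SqueezeExcite(in_ch=64, out_ch=64)
#   x = x * se(x)  # x is NHWC with 64 channels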


class StochDepth(hk.Module):
  """Batchwise Dropout used in EfficientNet, optionally sans rescaling."""

  def __init__(self, drop_rate, scale_by_keep=False, name=None):
    super().__init__(name=name)
    self.drop_rate = drop_rate
    self.scale_by_keep = scale_by_keep

  def __call__(self, x, is_training) -> jnp.ndarray:
    if not is_training:
      return x
    batch_size = x.shape[0]
    r = jax.random.uniform(hk.next_rng_key(), [batch_size, 1, 1, 1],
                           dtype=x.dtype)
    keep_prob = 1. - self.drop_rate
    binary_tensor = jnp.floor(keep_prob + r)
    if self.scale_by_keep:
      x = x / keep_prob
    return x * binary_tensor
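
# Notes (added): r ~ U[0, 1), so floor(keep_prob + r) is 1 with probability
# keep_prob and 0 otherwise -- a per-sample Bernoulli keep mask. In a
# residual block the whole branch is dropped per example, e.g.
#   out = shortcut + StochDepth(drop_rate=0.1)(branch, is_training)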