mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-10 05:17:46 +08:00
Fix initial convolution channels not multiplied by width_multiplier.
PiperOrigin-RevId: 327005191
This commit is contained in:
committed by
Louise Deason
parent
923ad3cff0
commit
63fa5e72d5
+49
-11
@@ -45,7 +45,12 @@ class MLP(hk.Module):
|
||||
return out
|
||||
|
||||
|
||||
class ResNetTorso(hk.nets.ResNet):
|
||||
def check_length(length, value, name):
|
||||
if len(value) != length:
|
||||
raise ValueError(f'`{name}` must be of length 4 not {len(value)}')
|
||||
|
||||
|
||||
class ResNetTorso(hk.Module):
|
||||
"""ResNet model."""
|
||||
|
||||
def __init__(
|
||||
@@ -80,16 +85,49 @@ class ResNetTorso(hk.nets.ResNet):
|
||||
width_multiplier: An integer multiplying the number of channels per group.
|
||||
name: Name of the module.
|
||||
"""
|
||||
channels_per_group = [width_multiplier * channel
|
||||
for channel in channels_per_group]
|
||||
super().__init__(
|
||||
blocks_per_group=blocks_per_group,
|
||||
num_classes=num_classes,
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=bottleneck,
|
||||
channels_per_group=channels_per_group,
|
||||
name=name)
|
||||
super().__init__(name=name)
|
||||
self.resnet_v2 = resnet_v2
|
||||
|
||||
bn_config = dict(bn_config or {})
|
||||
bn_config.setdefault('decay_rate', 0.9)
|
||||
bn_config.setdefault('eps', 1e-5)
|
||||
bn_config.setdefault('create_scale', True)
|
||||
bn_config.setdefault('create_offset', True)
|
||||
|
||||
# Number of blocks in each group for ResNet.
|
||||
check_length(4, blocks_per_group, 'blocks_per_group')
|
||||
check_length(4, channels_per_group, 'channels_per_group')
|
||||
|
||||
self.initial_conv = hk.Conv2D(
|
||||
output_channels=64 * width_multiplier,
|
||||
kernel_shape=7,
|
||||
stride=2,
|
||||
with_bias=False,
|
||||
padding='SAME',
|
||||
name='initial_conv')
|
||||
|
||||
if not self.resnet_v2:
|
||||
self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm',
|
||||
**bn_config)
|
||||
|
||||
self.block_groups = []
|
||||
strides = (1, 2, 2, 2)
|
||||
for i in range(4):
|
||||
self.block_groups.append(
|
||||
hk.nets.ResNet.BlockGroup(
|
||||
channels=width_multiplier * channels_per_group[i],
|
||||
num_blocks=blocks_per_group[i],
|
||||
stride=strides[i],
|
||||
bn_config=bn_config,
|
||||
resnet_v2=resnet_v2,
|
||||
bottleneck=bottleneck,
|
||||
use_projection=use_projection[i],
|
||||
name='block_group_%d' % (i)))
|
||||
|
||||
if self.resnet_v2:
|
||||
self.final_batchnorm = hk.BatchNorm(name='final_batchnorm', **bn_config)
|
||||
|
||||
self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')
|
||||
|
||||
def __call__(self, inputs, is_training, test_local_stats=False):
|
||||
out = inputs
|
||||
|
||||
Reference in New Issue
Block a user