mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-30 20:35:25 +08:00
Fix initial convolution channels not multiplied by width_multiplier.
PiperOrigin-RevId: 327005191
This commit is contained in:
committed by
Louise Deason
parent
923ad3cff0
commit
63fa5e72d5
+49
-11
@@ -45,7 +45,12 @@ class MLP(hk.Module):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ResNetTorso(hk.nets.ResNet):
|
def check_length(length, value, name):
  """Validates that `value` has exactly `length` elements.

  Args:
    length: The required number of elements.
    value: A sized collection (anything that supports `len`).
    name: Name of the argument being checked; used in the error message.

  Raises:
    ValueError: If `len(value)` differs from `length`.
  """
  if len(value) != length:
    # Report the requested length instead of a hard-coded 4 so the message
    # stays accurate for every caller, not only the 4-group ResNet checks.
    raise ValueError(
        f'`{name}` must be of length {length} not {len(value)}')
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetTorso(hk.Module):
|
||||||
"""ResNet model."""
|
"""ResNet model."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -80,16 +85,49 @@ class ResNetTorso(hk.nets.ResNet):
|
|||||||
width_multiplier: An integer multiplying the number of channels per group.
|
width_multiplier: An integer multiplying the number of channels per group.
|
||||||
name: Name of the module.
|
name: Name of the module.
|
||||||
"""
|
"""
|
||||||
channels_per_group = [width_multiplier * channel
|
super().__init__(name=name)
|
||||||
for channel in channels_per_group]
|
self.resnet_v2 = resnet_v2
|
||||||
super().__init__(
|
|
||||||
blocks_per_group=blocks_per_group,
|
bn_config = dict(bn_config or {})
|
||||||
num_classes=num_classes,
|
bn_config.setdefault('decay_rate', 0.9)
|
||||||
bn_config=bn_config,
|
bn_config.setdefault('eps', 1e-5)
|
||||||
resnet_v2=resnet_v2,
|
bn_config.setdefault('create_scale', True)
|
||||||
bottleneck=bottleneck,
|
bn_config.setdefault('create_offset', True)
|
||||||
channels_per_group=channels_per_group,
|
|
||||||
name=name)
|
# Number of blocks in each group for ResNet.
|
||||||
|
check_length(4, blocks_per_group, 'blocks_per_group')
|
||||||
|
check_length(4, channels_per_group, 'channels_per_group')
|
||||||
|
|
||||||
|
self.initial_conv = hk.Conv2D(
|
||||||
|
output_channels=64 * width_multiplier,
|
||||||
|
kernel_shape=7,
|
||||||
|
stride=2,
|
||||||
|
with_bias=False,
|
||||||
|
padding='SAME',
|
||||||
|
name='initial_conv')
|
||||||
|
|
||||||
|
if not self.resnet_v2:
|
||||||
|
self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm',
|
||||||
|
**bn_config)
|
||||||
|
|
||||||
|
self.block_groups = []
|
||||||
|
strides = (1, 2, 2, 2)
|
||||||
|
for i in range(4):
|
||||||
|
self.block_groups.append(
|
||||||
|
hk.nets.ResNet.BlockGroup(
|
||||||
|
channels=width_multiplier * channels_per_group[i],
|
||||||
|
num_blocks=blocks_per_group[i],
|
||||||
|
stride=strides[i],
|
||||||
|
bn_config=bn_config,
|
||||||
|
resnet_v2=resnet_v2,
|
||||||
|
bottleneck=bottleneck,
|
||||||
|
use_projection=use_projection[i],
|
||||||
|
name='block_group_%d' % (i)))
|
||||||
|
|
||||||
|
if self.resnet_v2:
|
||||||
|
self.final_batchnorm = hk.BatchNorm(name='final_batchnorm', **bn_config)
|
||||||
|
|
||||||
|
self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')
|
||||||
|
|
||||||
def __call__(self, inputs, is_training, test_local_stats=False):
|
def __call__(self, inputs, is_training, test_local_stats=False):
|
||||||
out = inputs
|
out = inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user