Fix initial convolution channels not multiplied by width_multiplier.

PiperOrigin-RevId: 327005191
This commit is contained in:
Florent Altché
2020-08-17 12:51:53 +00:00
committed by Louise Deason
parent 923ad3cff0
commit 63fa5e72d5
+49 -11
View File
@@ -45,7 +45,12 @@ class MLP(hk.Module):
return out
class ResNetTorso(hk.nets.ResNet):
def check_length(length, value, name):
    """Validate that `value` contains exactly `length` elements.

    Args:
      length: Required number of elements.
      value: Sized object (anything supporting `len`) to validate.
      name: Name of the argument being validated; used in the error message.

    Raises:
      ValueError: If `len(value)` differs from `length`.
    """
    if len(value) != length:
        # Report the length that was actually requested instead of a
        # hard-coded 4, so the message stays accurate for every caller.
        raise ValueError(
            f'`{name}` must be of length {length} not {len(value)}')
class ResNetTorso(hk.Module):
"""ResNet model."""
def __init__(
@@ -80,16 +85,49 @@ class ResNetTorso(hk.nets.ResNet):
width_multiplier: An integer multiplying the number of channels per group.
name: Name of the module.
"""
channels_per_group = [width_multiplier * channel
for channel in channels_per_group]
super().__init__(
blocks_per_group=blocks_per_group,
num_classes=num_classes,
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=bottleneck,
channels_per_group=channels_per_group,
name=name)
super().__init__(name=name)
self.resnet_v2 = resnet_v2
bn_config = dict(bn_config or {})
bn_config.setdefault('decay_rate', 0.9)
bn_config.setdefault('eps', 1e-5)
bn_config.setdefault('create_scale', True)
bn_config.setdefault('create_offset', True)
# Number of blocks in each group for ResNet.
check_length(4, blocks_per_group, 'blocks_per_group')
check_length(4, channels_per_group, 'channels_per_group')
self.initial_conv = hk.Conv2D(
output_channels=64 * width_multiplier,
kernel_shape=7,
stride=2,
with_bias=False,
padding='SAME',
name='initial_conv')
if not self.resnet_v2:
self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm',
**bn_config)
self.block_groups = []
strides = (1, 2, 2, 2)
for i in range(4):
self.block_groups.append(
hk.nets.ResNet.BlockGroup(
channels=width_multiplier * channels_per_group[i],
num_blocks=blocks_per_group[i],
stride=strides[i],
bn_config=bn_config,
resnet_v2=resnet_v2,
bottleneck=bottleneck,
use_projection=use_projection[i],
name='block_group_%d' % (i)))
if self.resnet_v2:
self.final_batchnorm = hk.BatchNorm(name='final_batchnorm', **bn_config)
self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')
def __call__(self, inputs, is_training, test_local_stats=False):
out = inputs