def check_length(length, value, name):
  """Validates that `value` contains exactly `length` elements.

  Args:
    length: The required number of elements.
    value: Any object supporting `len()` whose size is checked.
    name: Name of the argument being validated; only used to build the
      error message.

  Raises:
    ValueError: If `len(value)` is different from `length`.
  """
  actual = len(value)
  if actual != length:
    # Report the length that was actually requested instead of a hard-coded
    # 4, so the message stays correct for every caller.
    raise ValueError(f'`{name}` must be of length {length} not {actual}')
+ check_length(4, blocks_per_group, 'blocks_per_group') + check_length(4, channels_per_group, 'channels_per_group') + + self.initial_conv = hk.Conv2D( + output_channels=64 * width_multiplier, + kernel_shape=7, + stride=2, + with_bias=False, + padding='SAME', + name='initial_conv') + + if not self.resnet_v2: + self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm', + **bn_config) + + self.block_groups = [] + strides = (1, 2, 2, 2) + for i in range(4): + self.block_groups.append( + hk.nets.ResNet.BlockGroup( + channels=width_multiplier * channels_per_group[i], + num_blocks=blocks_per_group[i], + stride=strides[i], + bn_config=bn_config, + resnet_v2=resnet_v2, + bottleneck=bottleneck, + use_projection=use_projection[i], + name='block_group_%d' % (i))) + + if self.resnet_v2: + self.final_batchnorm = hk.BatchNorm(name='final_batchnorm', **bn_config) + + self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits') def __call__(self, inputs, is_training, test_local_stats=False): out = inputs