From 57456a001d91e8bd661d6f235f708cec3ffc799e Mon Sep 17 00:00:00 2001
From: Kyle Taylor
Date: Tue, 24 May 2022 18:36:03 +0100
Subject: [PATCH] Minor changes to match trained SOTA model.

PiperOrigin-RevId: 450716647
---
 enformer/attention_module.py |  3 ++-
 enformer/enformer.py         | 25 ++++++++++++++++---------
 enformer/requirements.txt    |  2 +-
 3 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/enformer/attention_module.py b/enformer/attention_module.py
index 25434e3..e966fb5 100644
--- a/enformer/attention_module.py
+++ b/enformer/attention_module.py
@@ -453,7 +453,8 @@ def positional_features_gamma(positions: tf.Tensor,
       tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
       concentration, rate)
   probabilities += 1e-8  # To ensure numerical stability.
-  outputs = probabilities / tf.reduce_max(probabilities)
+  outputs = probabilities / tf.reduce_max(probabilities,
+                                          axis=1, keepdims=True)
   tf.TensorShape(outputs.shape).assert_is_compatible_with(
       positions.shape + [feature_size])
   return outputs
diff --git a/enformer/enformer.py b/enformer/enformer.py
index 7fa7b58..c3ecbd2 100644
--- a/enformer/enformer.py
+++ b/enformer/enformer.py
@@ -87,10 +87,12 @@ class Enformer(snt.Module):
     # lambda is used in Sequential to construct the module under tf.name_scope.
     def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
       return Sequential(lambda: [
-          snt.BatchNorm(create_scale=True,
-                        create_offset=True,
-                        decay_rate=0.9,
-                        scale_init=snt.initializers.Ones()),
+          snt.distribute.CrossReplicaBatchNorm(
+              create_scale=True,
+              create_offset=True,
+              scale_init=snt.initializers.Ones(),
+              moving_mean=snt.ExponentialMovingAverage(0.9),
+              moving_variance=snt.ExponentialMovingAverage(0.9)),
           gelu,
           snt.Conv1D(filters, width, w_init=w_init, **kwargs)
       ], name=name)
@@ -184,16 +186,22 @@ class Enformer(snt.Module):
 class TargetLengthCrop1D(snt.Module):
   """Crop sequence to match the desired target length."""
 
-  def __init__(self, target_length: int, name='target_length_crop'):
+  def __init__(self,
+               target_length: Optional[int],
+               name: str = 'target_length_crop'):
     super().__init__(name=name)
     self._target_length = target_length
 
   def __call__(self, inputs):
+    if self._target_length is None:
+      return inputs
     trim = (inputs.shape[-2] - self._target_length) // 2
     if trim < 0:
       raise ValueError('inputs longer than target length')
-
-    return inputs[..., trim:-trim, :]
+    elif trim == 0:
+      return inputs
+    else:
+      return inputs[..., trim:-trim, :]
 
 
 class Sequential(snt.Module):
@@ -209,8 +217,7 @@ class Sequential(snt.Module):
     else:
       # layers wrapped in a lambda function to have a common namespace.
       if hasattr(layers, '__call__'):
-        with tf.name_scope(name):
-          layers = layers()
+        layers = layers()
     self._layers = [layer for layer in layers if layer is not None]
 
   def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
diff --git a/enformer/requirements.txt b/enformer/requirements.txt
index 38ecc6e..d2eb68e 100644
--- a/enformer/requirements.txt
+++ b/enformer/requirements.txt
@@ -2,5 +2,5 @@ dm-sonnet==2.0.0
 kipoiseq==0.5.2
 numpy==1.19.5
 pandas==1.2.3
-tensorflow==2.4.1
+tensorflow==2.5.0
 tensorflow-hub==0.11.0