Minor changes to match trained SOTA model.

PiperOrigin-RevId: 450716647
This commit is contained in:
Kyle Taylor
2022-05-24 18:36:03 +01:00
committed by alimuldal
parent d436681054
commit 57456a001d
3 changed files with 19 additions and 11 deletions

View File

@@ -453,7 +453,8 @@ def positional_features_gamma(positions: tf.Tensor,
tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
concentration, rate)
probabilities += 1e-8 # To ensure numerical stability.
outputs = probabilities / tf.reduce_max(probabilities)
outputs = probabilities / tf.reduce_max(probabilities,
axis=1, keepdims=True)
tf.TensorShape(outputs.shape).assert_is_compatible_with(
positions.shape + [feature_size])
return outputs

View File

@@ -87,10 +87,12 @@ class Enformer(snt.Module):
# lambda is used in Sequential to construct the module under tf.name_scope.
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
return Sequential(lambda: [
snt.BatchNorm(create_scale=True,
create_offset=True,
decay_rate=0.9,
scale_init=snt.initializers.Ones()),
snt.distribute.CrossReplicaBatchNorm(
create_scale=True,
create_offset=True,
scale_init=snt.initializers.Ones(),
moving_mean=snt.ExponentialMovingAverage(0.9),
moving_variance=snt.ExponentialMovingAverage(0.9)),
gelu,
snt.Conv1D(filters, width, w_init=w_init, **kwargs)
], name=name)
@@ -184,16 +186,22 @@ class Enformer(snt.Module):
class TargetLengthCrop1D(snt.Module):
"""Crop sequence to match the desired target length."""
def __init__(self, target_length: int, name='target_length_crop'):
def __init__(self,
target_length: Optional[int],
name: str = 'target_length_crop'):
super().__init__(name=name)
self._target_length = target_length
def __call__(self, inputs):
if self._target_length is None:
return inputs
trim = (inputs.shape[-2] - self._target_length) // 2
if trim < 0:
raise ValueError('inputs longer than target length')
return inputs[..., trim:-trim, :]
elif trim == 0:
return inputs
else:
return inputs[..., trim:-trim, :]
class Sequential(snt.Module):
@@ -209,8 +217,7 @@ class Sequential(snt.Module):
else:
# layers wrapped in a lambda function to have a common namespace.
if hasattr(layers, '__call__'):
with tf.name_scope(name):
layers = layers()
layers = layers()
self._layers = [layer for layer in layers if layer is not None]
def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):

View File

@@ -2,5 +2,5 @@ dm-sonnet==2.0.0
kipoiseq==0.5.2
numpy==1.19.5
pandas==1.2.3
tensorflow==2.4.1
tensorflow==2.5.0
tensorflow-hub==0.11.0