Mirror of https://github.com/google-deepmind/deepmind-research.git
Minor changes to match trained SOTA model.
PiperOrigin-RevId: 450716647
@@ -453,7 +453,8 @@ def positional_features_gamma(positions: tf.Tensor,
       tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
       concentration, rate)
   probabilities += 1e-8  # To ensure numerical stability.
-  outputs = probabilities / tf.reduce_max(probabilities)
+  outputs = probabilities / tf.reduce_max(probabilities,
+                                          axis=1, keepdims=True)
   tf.TensorShape(outputs.shape).assert_is_compatible_with(
       positions.shape + [feature_size])
   return outputs
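The normalization now divides by a maximum taken along axis 1 with keepdims=True, so each slice is scaled by its own maximum rather than by one scalar maximum over the whole tensor. A toy illustration of the broadcasting difference (the matrix values are made up and only stand in for the gamma-PDF probabilities):

import tensorflow as tf

# Stand-in for the gamma-PDF probabilities, shape [positions, features].
probabilities = tf.constant([[0.10, 0.80],
                             [0.40, 0.20],
                             [0.05, 0.01]]) + 1e-8

# Old behaviour: one global maximum (0.8 here), so only the single largest
# entry in the whole tensor ends up at exactly 1.0.
old = probabilities / tf.reduce_max(probabilities)

# New behaviour: the maximum is taken along axis 1 and kept as shape [3, 1],
# so it broadcasts row-wise and every row is scaled by its own maximum.
new = probabilities / tf.reduce_max(probabilities, axis=1, keepdims=True)

print(old.numpy())  # divided by the global max everywhere
print(new.numpy())  # each row divided by its own max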
@@ -87,10 +87,12 @@ class Enformer(snt.Module):
     # lambda is used in Sequential to construct the module under tf.name_scope.
     def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
       return Sequential(lambda: [
-          snt.BatchNorm(create_scale=True,
-                        create_offset=True,
-                        decay_rate=0.9,
-                        scale_init=snt.initializers.Ones()),
+          snt.distribute.CrossReplicaBatchNorm(
+              create_scale=True,
+              create_offset=True,
+              scale_init=snt.initializers.Ones(),
+              moving_mean=snt.ExponentialMovingAverage(0.9),
+              moving_variance=snt.ExponentialMovingAverage(0.9)),
          gelu,
          snt.Conv1D(filters, width, w_init=w_init, **kwargs)
      ], name=name)
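The conv block swaps snt.BatchNorm for snt.distribute.CrossReplicaBatchNorm, which aggregates batch statistics across replicas during distributed training; the single decay_rate argument is replaced by explicit snt.ExponentialMovingAverage modules with the same 0.9 decay. A minimal sketch of the two configurations side by side (the input tensor is a made-up stand-in, and the cross-replica module is only constructed here because it is meant to be called from inside a tf.distribute replica context, e.g. via Strategy.run):

import sonnet as snt
import tensorflow as tf

# Dummy activations shaped [batch, length, channels].
x = tf.random.normal([2, 16, 8])

# Old normalization: per-replica statistics, moving averages driven by decay_rate.
old_norm = snt.BatchNorm(create_scale=True,
                         create_offset=True,
                         decay_rate=0.9,
                         scale_init=snt.initializers.Ones())
y = old_norm(x, is_training=True)  # shape [2, 16, 8]

# New normalization: statistics aggregated across replicas, with the moving
# mean/variance supplied explicitly using the same 0.9 decay.
new_norm = snt.distribute.CrossReplicaBatchNorm(
    create_scale=True,
    create_offset=True,
    scale_init=snt.initializers.Ones(),
    moving_mean=snt.ExponentialMovingAverage(0.9),
    moving_variance=snt.ExponentialMovingAverage(0.9))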
@@ -184,16 +186,22 @@ class Enformer(snt.Module):
 class TargetLengthCrop1D(snt.Module):
   """Crop sequence to match the desired target length."""
 
-  def __init__(self, target_length: int, name='target_length_crop'):
+  def __init__(self,
+               target_length: Optional[int],
+               name: str = 'target_length_crop'):
     super().__init__(name=name)
     self._target_length = target_length
 
   def __call__(self, inputs):
+    if self._target_length is None:
+      return inputs
     trim = (inputs.shape[-2] - self._target_length) // 2
     if trim < 0:
       raise ValueError('inputs longer than target length')
-
-    return inputs[..., trim:-trim, :]
+    elif trim == 0:
+      return inputs
+    else:
+      return inputs[..., trim:-trim, :]
 
 
 class Sequential(snt.Module):
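TargetLengthCrop1D now accepts target_length=None as a pass-through and handles trim == 0 explicitly: when the input already has the target length, the old slice inputs[..., trim:-trim, :] degenerates to inputs[..., 0:0, :] and returns an empty tensor. A toy check of that slicing corner case (shapes and values are made up):

import tensorflow as tf

# Toy input shaped [batch, length, channels]; length 6 here.
inputs = tf.reshape(tf.range(2 * 6 * 3, dtype=tf.float32), [2, 6, 3])

# Cropping to a shorter target length trims `trim` positions from each end.
trim = (inputs.shape[-2] - 4) // 2               # 1
print(inputs[..., trim:-trim, :].shape)          # (2, 4, 3)

# The corner case the new `elif trim == 0` branch guards against: when the
# input already has the target length, trim is 0 and the slice becomes
# inputs[..., 0:0, :], an empty tensor instead of the unchanged input.
trim = (inputs.shape[-2] - 6) // 2               # 0
print(inputs[..., trim:-trim, :].shape)          # (2, 0, 3) -- not what we want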
@@ -209,8 +217,7 @@ class Sequential(snt.Module):
     else:
       # layers wrapped in a lambda function to have a common namespace.
       if hasattr(layers, '__call__'):
-        with tf.name_scope(name):
-          layers = layers()
+        layers = layers()
       self._layers = [layer for layer in layers if layer is not None]
 
   def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
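Sequential no longer opens an explicit tf.name_scope(name) when expanding the layer lambda, presumably because snt.Module already runs __init__ under the module's name scope, making the wrapper redundant. A simplified sketch of the lambda-constructor pattern, without the is_training plumbing of the real Sequential (MiniSequential and its layer sizes are hypothetical):

import sonnet as snt
import tensorflow as tf

class MiniSequential(snt.Module):
  """Builds sub-modules from a zero-argument callable, as Enformer's Sequential does."""

  def __init__(self, layers=None, name=None):
    super().__init__(name=name)
    if layers is None:
      self._layers = []
    else:
      # The layer list is passed as a lambda so the sub-modules are only
      # constructed here, inside this module's name scope.
      if hasattr(layers, '__call__'):
        layers = layers()
      self._layers = [layer for layer in layers if layer is not None]

  def __call__(self, inputs):
    outputs = inputs
    for layer in self._layers:
      outputs = layer(outputs)
    return outputs

block = MiniSequential(lambda: [snt.Linear(8), tf.nn.relu, snt.Linear(4)],
                       name='mini_block')
print(block(tf.random.normal([2, 16])).shape)  # (2, 4)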
@@ -2,5 +2,5 @@ dm-sonnet==2.0.0
 kipoiseq==0.5.2
 numpy==1.19.5
 pandas==1.2.3
-tensorflow==2.4.1
+tensorflow==2.5.0
 tensorflow-hub==0.11.0