mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 19:31:14 +08:00
Minor changes to match trained SOTA model.
PiperOrigin-RevId: 450716647
This commit is contained in:
@@ -453,7 +453,8 @@ def positional_features_gamma(positions: tf.Tensor,
|
|||||||
tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
|
tf.abs(tf.cast(positions, dtype=tf.float32))[..., tf.newaxis],
|
||||||
concentration, rate)
|
concentration, rate)
|
||||||
probabilities += 1e-8 # To ensure numerical stability.
|
probabilities += 1e-8 # To ensure numerical stability.
|
||||||
outputs = probabilities / tf.reduce_max(probabilities)
|
outputs = probabilities / tf.reduce_max(probabilities,
|
||||||
|
axis=1, keepdims=True)
|
||||||
tf.TensorShape(outputs.shape).assert_is_compatible_with(
|
tf.TensorShape(outputs.shape).assert_is_compatible_with(
|
||||||
positions.shape + [feature_size])
|
positions.shape + [feature_size])
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
+16
-9
@@ -87,10 +87,12 @@ class Enformer(snt.Module):
|
|||||||
# lambda is used in Sequential to construct the module under tf.name_scope.
|
# lambda is used in Sequential to construct the module under tf.name_scope.
|
||||||
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
|
def conv_block(filters, width=1, w_init=None, name='conv_block', **kwargs):
|
||||||
return Sequential(lambda: [
|
return Sequential(lambda: [
|
||||||
snt.BatchNorm(create_scale=True,
|
snt.distribute.CrossReplicaBatchNorm(
|
||||||
create_offset=True,
|
create_scale=True,
|
||||||
decay_rate=0.9,
|
create_offset=True,
|
||||||
scale_init=snt.initializers.Ones()),
|
scale_init=snt.initializers.Ones(),
|
||||||
|
moving_mean=snt.ExponentialMovingAverage(0.9),
|
||||||
|
moving_variance=snt.ExponentialMovingAverage(0.9)),
|
||||||
gelu,
|
gelu,
|
||||||
snt.Conv1D(filters, width, w_init=w_init, **kwargs)
|
snt.Conv1D(filters, width, w_init=w_init, **kwargs)
|
||||||
], name=name)
|
], name=name)
|
||||||
@@ -184,16 +186,22 @@ class Enformer(snt.Module):
|
|||||||
class TargetLengthCrop1D(snt.Module):
|
class TargetLengthCrop1D(snt.Module):
|
||||||
"""Crop sequence to match the desired target length."""
|
"""Crop sequence to match the desired target length."""
|
||||||
|
|
||||||
def __init__(self, target_length: int, name='target_length_crop'):
|
def __init__(self,
|
||||||
|
target_length: Optional[int],
|
||||||
|
name: str = 'target_length_crop'):
|
||||||
super().__init__(name=name)
|
super().__init__(name=name)
|
||||||
self._target_length = target_length
|
self._target_length = target_length
|
||||||
|
|
||||||
def __call__(self, inputs):
|
def __call__(self, inputs):
|
||||||
|
if self._target_length is None:
|
||||||
|
return inputs
|
||||||
trim = (inputs.shape[-2] - self._target_length) // 2
|
trim = (inputs.shape[-2] - self._target_length) // 2
|
||||||
if trim < 0:
|
if trim < 0:
|
||||||
raise ValueError('inputs longer than target length')
|
raise ValueError('inputs longer than target length')
|
||||||
|
elif trim == 0:
|
||||||
return inputs[..., trim:-trim, :]
|
return inputs
|
||||||
|
else:
|
||||||
|
return inputs[..., trim:-trim, :]
|
||||||
|
|
||||||
|
|
||||||
class Sequential(snt.Module):
|
class Sequential(snt.Module):
|
||||||
@@ -209,8 +217,7 @@ class Sequential(snt.Module):
|
|||||||
else:
|
else:
|
||||||
# layers wrapped in a lambda function to have a common namespace.
|
# layers wrapped in a lambda function to have a common namespace.
|
||||||
if hasattr(layers, '__call__'):
|
if hasattr(layers, '__call__'):
|
||||||
with tf.name_scope(name):
|
layers = layers()
|
||||||
layers = layers()
|
|
||||||
self._layers = [layer for layer in layers if layer is not None]
|
self._layers = [layer for layer in layers if layer is not None]
|
||||||
|
|
||||||
def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
|
def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs):
|
||||||
|
|||||||
@@ -2,5 +2,5 @@ dm-sonnet==2.0.0
|
|||||||
kipoiseq==0.5.2
|
kipoiseq==0.5.2
|
||||||
numpy==1.19.5
|
numpy==1.19.5
|
||||||
pandas==1.2.3
|
pandas==1.2.3
|
||||||
tensorflow==2.4.1
|
tensorflow==2.5.0
|
||||||
tensorflow-hub==0.11.0
|
tensorflow-hub==0.11.0
|
||||||
|
|||||||
Reference in New Issue
Block a user