mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00
Switch types to prevent types switching to doubles.
PiperOrigin-RevId: 365082831
This commit is contained in:
committed by
Louise Deason
parent
d9fb969489
commit
08ed21473d
@@ -197,11 +197,11 @@ class MultiheadAttention(snt.Module):
|
||||
w_init=self._initializer)
|
||||
self._r_w_bias = tf.Variable(
|
||||
self._initializer([1, self._num_heads, 1, self._key_size],
|
||||
dtype=tf.float64),
|
||||
dtype=tf.float32),
|
||||
name='r_w_bias')
|
||||
self._r_r_bias = tf.Variable(
|
||||
self._initializer([1, self._num_heads, 1, self._key_size],
|
||||
dtype=tf.float64),
|
||||
dtype=tf.float32),
|
||||
name='r_r_bias')
|
||||
|
||||
def _multihead_output(self, linear, inputs):
|
||||
@@ -254,7 +254,7 @@ class MultiheadAttention(snt.Module):
|
||||
content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True)
|
||||
# [B, H, T', 2T-1]
|
||||
relative_logits = tf.matmul(
|
||||
q + self._r_r_bias, tf.cast(r_k, tf.float64), transpose_b=True)
|
||||
q + self._r_r_bias, r_k, transpose_b=True)
|
||||
# [B, H, T', T]
|
||||
relative_logits = relative_shift(relative_logits)
|
||||
logits = content_logits + relative_logits
|
||||
|
||||
@@ -289,7 +289,7 @@ def one_hot_encode(sequence: str,
|
||||
alphabet: str = 'ACGT',
|
||||
neutral_alphabet: str = 'N',
|
||||
neutral_value: Any = 0,
|
||||
dtype=np.float64) -> np.ndarray:
|
||||
dtype=np.float32) -> np.ndarray:
|
||||
"""One-hot encode sequence."""
|
||||
def to_uint8(string):
|
||||
return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestEnformer(unittest.TestCase):
|
||||
def _get_random_input():
|
||||
seq = ''.join(
|
||||
[random.choice('ACGT') for _ in range(enformer.SEQUENCE_LENGTH)])
|
||||
return np.expand_dims(enformer.one_hot_encode(seq), 0)
|
||||
return np.expand_dims(enformer.one_hot_encode(seq), 0).astype(np.float32)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user