Switch types to prevent types switching to doubles.

PiperOrigin-RevId: 365082831
This commit is contained in:
Kyle Taylor
2021-03-25 18:17:46 +00:00
committed by Louise Deason
parent d9fb969489
commit 08ed21473d
3 changed files with 5 additions and 5 deletions

View File

@@ -197,11 +197,11 @@ class MultiheadAttention(snt.Module):
w_init=self._initializer)
self._r_w_bias = tf.Variable(
self._initializer([1, self._num_heads, 1, self._key_size],
dtype=tf.float64),
dtype=tf.float32),
name='r_w_bias')
self._r_r_bias = tf.Variable(
self._initializer([1, self._num_heads, 1, self._key_size],
dtype=tf.float64),
dtype=tf.float32),
name='r_r_bias')
def _multihead_output(self, linear, inputs):
@@ -254,7 +254,7 @@ class MultiheadAttention(snt.Module):
content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True)
# [B, H, T', 2T-1]
relative_logits = tf.matmul(
q + self._r_r_bias, tf.cast(r_k, tf.float64), transpose_b=True)
q + self._r_r_bias, r_k, transpose_b=True)
# [B, H, T', T]
relative_logits = relative_shift(relative_logits)
logits = content_logits + relative_logits

View File

@@ -289,7 +289,7 @@ def one_hot_encode(sequence: str,
alphabet: str = 'ACGT',
neutral_alphabet: str = 'N',
neutral_value: Any = 0,
dtype=np.float64) -> np.ndarray:
dtype=np.float32) -> np.ndarray:
"""One-hot encode sequence."""
def to_uint8(string):
return np.frombuffer(string.encode('ascii'), dtype=np.uint8)

View File

@@ -38,7 +38,7 @@ class TestEnformer(unittest.TestCase):
def _get_random_input():
seq = ''.join(
[random.choice('ACGT') for _ in range(enformer.SEQUENCE_LENGTH)])
return np.expand_dims(enformer.one_hot_encode(seq), 0)
return np.expand_dims(enformer.one_hot_encode(seq), 0).astype(np.float32)
if __name__ == '__main__':