From 08ed21473d4e442358c78d4ef81302930cb569fd Mon Sep 17 00:00:00 2001
From: Kyle Taylor
Date: Thu, 25 Mar 2021 18:17:46 +0000
Subject: [PATCH] Switch dtypes to float32 to prevent tensors from silently
 switching to doubles.

PiperOrigin-RevId: 365082831
---
 enformer/attention_module.py | 6 +++---
 enformer/enformer.py         | 2 +-
 enformer/enformer_test.py    | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/enformer/attention_module.py b/enformer/attention_module.py
index 9383ddd..25434e3 100644
--- a/enformer/attention_module.py
+++ b/enformer/attention_module.py
@@ -197,11 +197,11 @@ class MultiheadAttention(snt.Module):
         w_init=self._initializer)
     self._r_w_bias = tf.Variable(
         self._initializer([1, self._num_heads, 1, self._key_size],
-                          dtype=tf.float64),
+                          dtype=tf.float32),
         name='r_w_bias')
     self._r_r_bias = tf.Variable(
         self._initializer([1, self._num_heads, 1, self._key_size],
-                          dtype=tf.float64),
+                          dtype=tf.float32),
         name='r_r_bias')
 
   def _multihead_output(self, linear, inputs):
@@ -254,7 +254,7 @@ class MultiheadAttention(snt.Module):
     content_logits = tf.matmul(q + self._r_w_bias, k, transpose_b=True)
     # [B, H, T', 2T-1]
     relative_logits = tf.matmul(
-        q + self._r_r_bias, tf.cast(r_k, tf.float64), transpose_b=True)
+        q + self._r_r_bias, r_k, transpose_b=True)
     # [B, H, T', T]
     relative_logits = relative_shift(relative_logits)
     logits = content_logits + relative_logits
diff --git a/enformer/enformer.py b/enformer/enformer.py
index 4be87f0..40e5ec1 100644
--- a/enformer/enformer.py
+++ b/enformer/enformer.py
@@ -289,7 +289,7 @@ def one_hot_encode(sequence: str,
                    alphabet: str = 'ACGT',
                    neutral_alphabet: str = 'N',
                    neutral_value: Any = 0,
-                   dtype=np.float64) -> np.ndarray:
+                   dtype=np.float32) -> np.ndarray:
   """One-hot encode sequence."""
   def to_uint8(string):
     return np.frombuffer(string.encode('ascii'), dtype=np.uint8)
diff --git a/enformer/enformer_test.py b/enformer/enformer_test.py
index 21b5777..69c2a66 100644
--- a/enformer/enformer_test.py
+++ b/enformer/enformer_test.py
@@ -38,7 +38,7 @@ class TestEnformer(unittest.TestCase):
 def _get_random_input():
   seq = ''.join(
       [random.choice('ACGT') for _ in range(enformer.SEQUENCE_LENGTH)])
-  return np.expand_dims(enformer.one_hot_encode(seq), 0)
+  return np.expand_dims(enformer.one_hot_encode(seq), 0).astype(np.float32)
 
 
 if __name__ == '__main__':
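
Editor's note on the motivation (the sketch below is an illustration, not part of the patch): NumPy arrays default to float64, so the old dtype=np.float64 in one_hot_encode fed double-precision inputs into the model, and because TensorFlow ops require operands of identical dtype, the attention module ended up needing float64 biases and the tf.cast(r_k, tf.float64) that this patch deletes. A minimal sketch of the failure mode, assuming TensorFlow 2.x in eager mode, with shapes that are illustrative only:

import numpy as np
import tensorflow as tf

# NumPy defaults to float64, and tf.constant preserves that dtype, so a
# float64 one-hot encoding drags everything downstream to double precision.
one_hot = np.eye(4)[[0, 1, 2, 3]]           # dtype float64 by default
x = tf.constant(one_hot)
print(x.dtype)                              # tf.float64

# TensorFlow ops require matching dtypes: multiplying a float64 tensor
# with a float32 one raises an error, which is what the now-deleted
# tf.cast(r_k, tf.float64) was compensating for.
q = tf.ones([4, 8], dtype=tf.float32)
try:
  tf.matmul(x, q)                           # float64 x float32
except tf.errors.InvalidArgumentError as err:
  print('dtype mismatch:', type(err).__name__)

# Pinning float32 at the source removes the silent precision switch
# and makes the compensating cast unnecessary.
x32 = tf.constant(one_hot, dtype=tf.float32)
print(tf.matmul(x32, q).dtype)              # tf.float32

Fixing the dtype at every entry point (the bias initializers, the one_hot_encode default, and the test input) keeps the whole model in single precision, which is what accelerators handle efficiently, rather than papering over the mismatch with casts inside the attention computation.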