Add split_rng=False (current default) to HTM.

Haiku plans to make split_rng a required argument to hk.vmap in an upcoming
release. This change updates HTM to preserve the current behaviour. We
also handle the case where users are using a release of Haiku without the
split_rng option, for these users split_rng=False is implied.

PiperOrigin-RevId: 428454975
This commit is contained in:
Tom Hennigan
2022-02-14 10:33:11 +00:00
committed by Diego de Las Casas
parent 840bfe86a6
commit 7b17dd5dde

View File

@@ -14,6 +14,8 @@
"""Haiku module implementing hierarchical attention over memory."""
import functools
import inspect
from typing import Optional, NamedTuple
import chex
@@ -198,7 +200,7 @@ class HierarchicalMemoryAttention(hk.Module):
key=sub_sub_top_k_contents,
value=sub_sub_top_k_contents)
return sub_attention_results
do_attention = hk.vmap(do_attention, in_axes=0)
do_attention = hk_vmap(do_attention, in_axes=0, split_rng=False)
attention_results = do_attention(sub_inputs, top_k_contents)
attention_results = jnp.squeeze(attention_results, axis=2)
# Now collapse results across k memories
@@ -207,8 +209,8 @@ class HierarchicalMemoryAttention(hk.Module):
return attention_results
# vmap across batch
batch_within_memory_attention = hk.vmap(_within_memory_attention,
in_axes=0)
batch_within_memory_attention = hk_vmap(_within_memory_attention,
in_axes=0, split_rng=False)
outputs = batch_within_memory_attention(
queries,
jax.lax.stop_gradient(augmented_contents),
@@ -216,3 +218,18 @@ class HierarchicalMemoryAttention(hk.Module):
top_k_indices)
return outputs
@functools.wraps(hk.vmap)
def hk_vmap(*args, **kwargs):
"""Helper function to support older versions of Haiku."""
# Older versions of Haiku did not have split_rng, but the behavior has always
# been equivalent to split_rng=False.
if "split_rng" not in inspect.signature(hk.vmap).parameters:
kwargs.setdefault("split_rng", False)
if kwargs.get["split_rng"]:
raise ValueError("The installed version of Haiku only supports "
"`split_rng=False`, please upgrade Haiku.")
del kwargs["split_rng"]
return hk.vmap(*args, **kwargs)