mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-02-06 03:32:18 +08:00
Add split_rng=False (current default) to HTM.
Haiku plans to make split_rng a required argument to hk.vmap in an upcoming release. This change updates HTM to preserve the current behaviour. We also handle the case where users are using a release of Haiku without the split_rng option; for these users, split_rng=False is implied. PiperOrigin-RevId: 428454975
This commit is contained in:
committed by
Diego de Las Casas
parent
840bfe86a6
commit
7b17dd5dde
@@ -14,6 +14,8 @@
|
||||
|
||||
"""Haiku module implementing hierarchical attention over memory."""
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
from typing import Optional, NamedTuple
|
||||
|
||||
import chex
|
||||
@@ -198,7 +200,7 @@ class HierarchicalMemoryAttention(hk.Module):
|
||||
key=sub_sub_top_k_contents,
|
||||
value=sub_sub_top_k_contents)
|
||||
return sub_attention_results
|
||||
do_attention = hk.vmap(do_attention, in_axes=0)
|
||||
do_attention = hk_vmap(do_attention, in_axes=0, split_rng=False)
|
||||
attention_results = do_attention(sub_inputs, top_k_contents)
|
||||
attention_results = jnp.squeeze(attention_results, axis=2)
|
||||
# Now collapse results across k memories
|
||||
@@ -207,8 +209,8 @@ class HierarchicalMemoryAttention(hk.Module):
|
||||
return attention_results
|
||||
|
||||
# vmap across batch
|
||||
batch_within_memory_attention = hk.vmap(_within_memory_attention,
|
||||
in_axes=0)
|
||||
batch_within_memory_attention = hk_vmap(_within_memory_attention,
|
||||
in_axes=0, split_rng=False)
|
||||
outputs = batch_within_memory_attention(
|
||||
queries,
|
||||
jax.lax.stop_gradient(augmented_contents),
|
||||
@@ -216,3 +218,18 @@ class HierarchicalMemoryAttention(hk.Module):
|
||||
top_k_indices)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
@functools.wraps(hk.vmap)
def hk_vmap(*args, **kwargs):
  """Calls `hk.vmap`, tolerating Haiku releases that lack `split_rng`.

  When the installed `hk.vmap` does not declare a `split_rng` parameter, the
  keyword is stripped before the call is forwarded. Such releases are treated
  as always behaving like `split_rng=False`, so a truthy value is rejected.

  Args:
    *args: Positional arguments forwarded unchanged to `hk.vmap`.
    **kwargs: Keyword arguments forwarded to `hk.vmap`. `split_rng` is removed
      when the installed `hk.vmap` does not accept it.

  Returns:
    Whatever `hk.vmap` returns for the forwarded arguments.

  Raises:
    ValueError: If a truthy `split_rng` is requested but the installed
      `hk.vmap` has no `split_rng` parameter.
  """
  # Older versions of Haiku did not have split_rng, but the behavior has always
  # been equivalent to split_rng=False.
  if "split_rng" not in inspect.signature(hk.vmap).parameters:
    # pop() reads the requested value and removes the key in one step, so the
    # unsupported keyword never reaches hk.vmap. A missing key is treated as
    # False, matching the implied behaviour of these releases.
    if kwargs.pop("split_rng", False):
      raise ValueError("The installed version of Haiku only supports "
                       "`split_rng=False`, please upgrade Haiku.")

  return hk.vmap(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user