mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-27 10:15:44 +08:00
Open sourcing the Ballet environment from "Towards Mental Time Travel..." which offers a challenging environment for remembering sequential events.
PiperOrigin-RevId: 387333851
This commit is contained in:
committed by
Diego de Las Casas
parent
a9dfb59097
commit
837430d4d1
@@ -0,0 +1,54 @@
|
|||||||
|
# Perceiver and Perceiver IO
|
||||||
|
|
||||||
|
Perceiver [1] is a general architecture that works on many kinds of data,
|
||||||
|
including images, video, audio, 3D point clouds, language and symbolic inputs,
|
||||||
|
multimodal combinations, etc.
|
||||||
|
Perceivers can handle new types of data with only minimal modifications.
|
||||||
|
Perceivers process inputs using domain-agnostic Transformer-style attention.
|
||||||
|
Unlike Transformers, Perceivers first map inputs to a small latent space where
|
||||||
|
processing is cheap and doesn't depend on the input size.
|
||||||
|
This makes it possible to build very deep networks
|
||||||
|
even when using large inputs like images or videos.
|
||||||
|
|
||||||
|
Perceiver IO [2] is a generalization of Perceiver to handle arbitrary *outputs*
|
||||||
|
in addition to arbitrary inputs.
|
||||||
|
The original Perceiver only produced a single classification label.
|
||||||
|
In addition to classification labels,
|
||||||
|
Perceiver IO can produce (for example) language, optical flow,
|
||||||
|
and multimodal videos with audio.
|
||||||
|
This is done using the same building blocks as the original Perceiver.
|
||||||
|
The computational complexity of Perceiver IO is linear in the input and output
|
||||||
|
size and the bulk of the processing occurs in the latent space,
|
||||||
|
allowing us to process inputs and outputs that are much larger
|
||||||
|
than can be handled by standard Transformers.
|
||||||
|
This means, for example, Perceiver IO can do BERT-style masked language modeling
|
||||||
|
directly using *bytes* instead of tokenized inputs.
|
||||||
|
|
||||||
|
This directory contains our implementation of Perceiver IO
|
||||||
|
(encompassing the original Perceiver as a special case).
|
||||||
|
The `perceiver.py` file contains our implementation of Perceiver IO,
|
||||||
|
and `io_processors.py` contains domain-specific input and output processors
|
||||||
|
for the experiments we ran.
|
||||||
|
We provide example colabs in the `colabs` directory to demonstrate
|
||||||
|
how our models can be used and show the qualitative performance of Perceiver IO.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
First, install dependencies following these instructions:
|
||||||
|
|
||||||
|
1. Create a virtual env: `python3 -m venv ~/.venv/perceiver`
|
||||||
|
2. Switch to the virtual env: `source ~/.venv/perceiver/bin/activate`
|
||||||
|
3. Follow instructions for installing JAX on your platform:
|
||||||
|
https://github.com/google/jax#installation
|
||||||
|
4. Install other dependencies: `pip install -f requirements.txt`
|
||||||
|
|
||||||
|
After install dependencies, you can open the notebooks in the `colabs` directory
|
||||||
|
using Jupyter or Colab.
|
||||||
|
|
||||||
|
## References
|
||||||
|
|
||||||
|
[1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals,
|
||||||
|
Joao Carreira.
|
||||||
|
*Perceiver: General Perception with Iterative Attention*. ICML 2021.
|
||||||
|
|
||||||
|
[2] TODO: Add citation after paper is published on ArXiv.
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,222 @@
|
|||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Position encodings and utilities."""
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import functools
|
||||||
|
|
||||||
|
import haiku as hk
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def generate_fourier_features(
|
||||||
|
pos, num_bands, max_resolution=(224, 224),
|
||||||
|
concat_pos=True, sine_only=False):
|
||||||
|
"""Generate a Fourier frequency position encoding with linear spacing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pos: The position of n points in d dimensional space.
|
||||||
|
A jnp array of shape [n, d].
|
||||||
|
num_bands: The number of bands (K) to use.
|
||||||
|
max_resolution: The maximum resolution (i.e. the number of pixels per dim).
|
||||||
|
A tuple representing resolution for each dimension
|
||||||
|
concat_pos: Concatenate the input position encoding to the Fourier features?
|
||||||
|
sine_only: Whether to use a single phase (sin) or two (sin/cos) for each
|
||||||
|
frequency band.
|
||||||
|
Returns:
|
||||||
|
embedding: A 1D jnp array of shape [n, n_channels]. If concat_pos is True
|
||||||
|
and sine_only is False, output dimensions are ordered as:
|
||||||
|
[dim_1, dim_2, ..., dim_d,
|
||||||
|
sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ...,
|
||||||
|
sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d),
|
||||||
|
cos(pi*f_1*dim_1), ..., cos(pi*f_K*dim_1), ...,
|
||||||
|
cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)],
|
||||||
|
where dim_i is pos[:, i] and f_k is the kth frequency band.
|
||||||
|
"""
|
||||||
|
min_freq = 1.0
|
||||||
|
# Nyquist frequency at the target resolution:
|
||||||
|
|
||||||
|
freq_bands = jnp.stack([
|
||||||
|
jnp.linspace(min_freq, res / 2, num=num_bands, endpoint=True)
|
||||||
|
for res in max_resolution], axis=0)
|
||||||
|
|
||||||
|
# Get frequency bands for each spatial dimension.
|
||||||
|
# Output is size [n, d * num_bands]
|
||||||
|
per_pos_features = pos[:, :, None] * freq_bands[None, :, :]
|
||||||
|
per_pos_features = jnp.reshape(per_pos_features,
|
||||||
|
[-1, np.prod(per_pos_features.shape[1:])])
|
||||||
|
|
||||||
|
if sine_only:
|
||||||
|
# Output is size [n, d * num_bands]
|
||||||
|
per_pos_features = jnp.sin(jnp.pi * (per_pos_features))
|
||||||
|
else:
|
||||||
|
# Output is size [n, 2 * d * num_bands]
|
||||||
|
per_pos_features = jnp.concatenate(
|
||||||
|
[jnp.sin(jnp.pi * per_pos_features),
|
||||||
|
jnp.cos(jnp.pi * per_pos_features)], axis=-1)
|
||||||
|
# Concatenate the raw input positions.
|
||||||
|
if concat_pos:
|
||||||
|
# Adds d bands to the encoding.
|
||||||
|
per_pos_features = jnp.concatenate([pos, per_pos_features], axis=-1)
|
||||||
|
return per_pos_features
|
||||||
|
|
||||||
|
|
||||||
|
def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
|
||||||
|
"""Generate an array of position indices for an N-D input array.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
index_dims: The shape of the index dimensions of the input array.
|
||||||
|
output_range: The min and max values taken by each input index dimension.
|
||||||
|
Returns:
|
||||||
|
A jnp array of shape [index_dims[0], index_dims[1], .., index_dims[-1], N].
|
||||||
|
"""
|
||||||
|
def _linspace(n_xels_per_dim):
|
||||||
|
return jnp.linspace(
|
||||||
|
output_range[0], output_range[1],
|
||||||
|
num=n_xels_per_dim,
|
||||||
|
endpoint=True, dtype=jnp.float32)
|
||||||
|
|
||||||
|
dim_ranges = [
|
||||||
|
_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
|
||||||
|
array_index_grid = jnp.meshgrid(*dim_ranges, indexing='ij')
|
||||||
|
|
||||||
|
return jnp.stack(array_index_grid, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractPositionEncoding(hk.Module, metaclass=abc.ABCMeta):
|
||||||
|
"""Abstract Perceiver decoder."""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def __call__(self, batch_size, pos):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TrainablePositionEncoding(AbstractPositionEncoding):
|
||||||
|
"""Trainable position encoding."""
|
||||||
|
|
||||||
|
def __init__(self, index_dim, num_channels=128, init_scale=0.02, name=None):
|
||||||
|
super(TrainablePositionEncoding, self).__init__(name=name)
|
||||||
|
self._index_dim = index_dim
|
||||||
|
self._num_channels = num_channels
|
||||||
|
self._init_scale = init_scale
|
||||||
|
|
||||||
|
def __call__(self, batch_size, pos=None):
|
||||||
|
del pos # Unused.
|
||||||
|
pos_embs = hk.get_parameter(
|
||||||
|
'pos_embs', [self._index_dim, self._num_channels],
|
||||||
|
init=hk.initializers.TruncatedNormal(stddev=self._init_scale))
|
||||||
|
|
||||||
|
if batch_size is not None:
|
||||||
|
pos_embs = jnp.broadcast_to(
|
||||||
|
pos_embs[None, :, :], (batch_size,) + pos_embs.shape)
|
||||||
|
return pos_embs
|
||||||
|
|
||||||
|
|
||||||
|
def _check_or_build_spatial_positions(pos, index_dims, batch_size):
|
||||||
|
"""Checks or builds spatial position features (x, y, ...).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pos: None, or an array of position features. If None, position features
|
||||||
|
are built. Otherwise, their size is checked.
|
||||||
|
index_dims: An iterable giving the spatial/index size of the data to be
|
||||||
|
featurized.
|
||||||
|
batch_size: The batch size of the data to be featurized.
|
||||||
|
Returns:
|
||||||
|
An array of position features, of shape [batch_size, prod(index_dims)].
|
||||||
|
"""
|
||||||
|
if pos is None:
|
||||||
|
pos = build_linear_positions(index_dims)
|
||||||
|
pos = jnp.broadcast_to(pos[None], (batch_size,) + pos.shape)
|
||||||
|
pos = jnp.reshape(pos, [batch_size, np.prod(index_dims), -1])
|
||||||
|
else:
|
||||||
|
# Just a warning label: you probably don't want your spatial features to
|
||||||
|
# have a different spatial layout than your pos coordinate system.
|
||||||
|
# But feel free to override if you think it'll work!
|
||||||
|
assert pos.shape[-1] == len(index_dims)
|
||||||
|
|
||||||
|
return pos
|
||||||
|
|
||||||
|
|
||||||
|
class FourierPositionEncoding(AbstractPositionEncoding):
|
||||||
|
"""Fourier (Sinusoidal) position encoding."""
|
||||||
|
|
||||||
|
def __init__(self, index_dims, num_bands, concat_pos=True,
|
||||||
|
max_resolution=None, sine_only=False, name=None):
|
||||||
|
super(FourierPositionEncoding, self).__init__(name=name)
|
||||||
|
self._num_bands = num_bands
|
||||||
|
self._concat_pos = concat_pos
|
||||||
|
self._sine_only = sine_only
|
||||||
|
self._index_dims = index_dims
|
||||||
|
# Use the index dims as the maximum resolution if it's not provided.
|
||||||
|
self._max_resolution = max_resolution or index_dims
|
||||||
|
|
||||||
|
def __call__(self, batch_size, pos=None):
|
||||||
|
pos = _check_or_build_spatial_positions(pos, self._index_dims, batch_size)
|
||||||
|
build_ff_fn = functools.partial(
|
||||||
|
generate_fourier_features,
|
||||||
|
num_bands=self._num_bands,
|
||||||
|
max_resolution=self._max_resolution,
|
||||||
|
concat_pos=self._concat_pos,
|
||||||
|
sine_only=self._sine_only)
|
||||||
|
return jax.vmap(build_ff_fn, 0, 0)(pos)
|
||||||
|
|
||||||
|
|
||||||
|
class PositionEncodingProjector(AbstractPositionEncoding):
|
||||||
|
"""Projects a position encoding to a target size."""
|
||||||
|
|
||||||
|
def __init__(self, output_size, base_position_encoding, name=None):
|
||||||
|
super(PositionEncodingProjector, self).__init__(name=name)
|
||||||
|
self._output_size = output_size
|
||||||
|
self._base_position_encoding = base_position_encoding
|
||||||
|
|
||||||
|
def __call__(self, batch_size, pos=None):
|
||||||
|
base_pos = self._base_position_encoding(batch_size, pos)
|
||||||
|
projected_pos = hk.Linear(output_size=self._output_size)(base_pos)
|
||||||
|
return projected_pos
|
||||||
|
|
||||||
|
|
||||||
|
def build_position_encoding(
|
||||||
|
position_encoding_type,
|
||||||
|
index_dims,
|
||||||
|
project_pos_dim=-1,
|
||||||
|
trainable_position_encoding_kwargs=None,
|
||||||
|
fourier_position_encoding_kwargs=None,
|
||||||
|
name=None):
|
||||||
|
"""Builds the position encoding."""
|
||||||
|
|
||||||
|
if position_encoding_type == 'trainable':
|
||||||
|
assert trainable_position_encoding_kwargs is not None
|
||||||
|
output_pos_enc = TrainablePositionEncoding(
|
||||||
|
# Construct 1D features:
|
||||||
|
index_dim=np.prod(index_dims),
|
||||||
|
name=name,
|
||||||
|
**trainable_position_encoding_kwargs)
|
||||||
|
elif position_encoding_type == 'fourier':
|
||||||
|
assert fourier_position_encoding_kwargs is not None
|
||||||
|
output_pos_enc = FourierPositionEncoding(
|
||||||
|
index_dims=index_dims,
|
||||||
|
name=name,
|
||||||
|
**fourier_position_encoding_kwargs)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unknown position encoding: {position_encoding_type}.')
|
||||||
|
|
||||||
|
if project_pos_dim > 0:
|
||||||
|
# Project the position encoding to a target dimension:
|
||||||
|
output_pos_enc = PositionEncodingProjector(
|
||||||
|
output_size=project_pos_dim,
|
||||||
|
base_position_encoding=output_pos_enc)
|
||||||
|
|
||||||
|
return output_pos_enc
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
absl-py==0.13.0
|
||||||
|
dill==0.3.4
|
||||||
|
dm-haiku==0.0.4
|
||||||
|
einops==0.3.0
|
||||||
|
flatbuffers==2.0
|
||||||
|
imageio==2.9.0
|
||||||
|
immutabledict==2.0.0
|
||||||
|
jax==0.2.16
|
||||||
|
jaxlib==0.1.68+cuda111
|
||||||
|
numpy==1.21.0
|
||||||
|
opencv-python==4.5.2.54
|
||||||
|
opt-einsum==3.3.0
|
||||||
|
Pillow==8.3.1
|
||||||
|
scipy==1.7.0
|
||||||
|
six==1.16.0
|
||||||
|
tabulate==0.8.9
|
||||||
Reference in New Issue
Block a user