mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-25 16:55:19 +08:00
Open sourcing the Ballet environment from "Towards Mental Time Travel..." which offers a challenging environment for remembering sequential events.
PiperOrigin-RevId: 387333851
This commit is contained in:
committed by
Diego de Las Casas
parent
a9dfb59097
commit
837430d4d1
@@ -0,0 +1,54 @@
|
||||
# Perceiver and Perceiver IO
|
||||
|
||||
Perceiver [1] is a general architecture that works on many kinds of data,
|
||||
including images, video, audio, 3D point clouds, language and symbolic inputs,
|
||||
multimodal combinations, etc.
|
||||
Perceivers can handle new types of data with only minimal modifications.
|
||||
Perceivers process inputs using domain-agnostic Transformer-style attention.
|
||||
Unlike Transformers, Perceivers first map inputs to a small latent space where
|
||||
processing is cheap and doesn't depend on the input size.
|
||||
This makes it possible to build very deep networks
|
||||
even when using large inputs like images or videos.
|
||||
|
||||
Perceiver IO [2] is a generalization of Perceiver to handle arbitrary *outputs*
|
||||
in addition to arbitrary inputs.
|
||||
The original Perceiver only produced a single classification label.
|
||||
In addition to classification labels,
|
||||
Perceiver IO can produce (for example) language, optical flow,
|
||||
and multimodal videos with audio.
|
||||
This is done using the same building blocks as the original Perceiver.
|
||||
The computational complexity of Perceiver IO is linear in the input and output
|
||||
size and the bulk of the processing occurs in the latent space,
|
||||
allowing us to process inputs and outputs that are much larger
|
||||
than can be handled by standard Transformers.
|
||||
This means, for example, Perceiver IO can do BERT-style masked language modeling
|
||||
directly using *bytes* instead of tokenized inputs.
|
||||
|
||||
This directory contains our implementation of Perceiver IO
|
||||
(encompassing the original Perceiver as a special case).
|
||||
The `perceiver.py` file contains our implementation of Perceiver IO,
|
||||
and `io_processors.py` contains domain-specific input and output processors
|
||||
for the experiments we ran.
|
||||
We provide example colabs in the `colabs` directory to demonstrate
|
||||
how our models can be used and show the qualitative performance of Perceiver IO.
|
||||
|
||||
## Usage
|
||||
|
||||
First, install dependencies following these instructions:
|
||||
|
||||
1. Create a virtual env: `python3 -m venv ~/.venv/perceiver`
|
||||
2. Switch to the virtual env: `source ~/.venv/perceiver/bin/activate`
|
||||
3. Follow instructions for installing JAX on your platform:
|
||||
https://github.com/google/jax#installation
|
||||
4. Install other dependencies: `pip install -f requirements.txt`
|
||||
|
||||
After install dependencies, you can open the notebooks in the `colabs` directory
|
||||
using Jupyter or Colab.
|
||||
|
||||
## References
|
||||
|
||||
[1] Andrew Jaegle, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals,
|
||||
Joao Carreira.
|
||||
*Perceiver: General Perception with Iterative Attention*. ICML 2021.
|
||||
|
||||
[2] TODO: Add citation after paper is published on ArXiv.
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,222 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Position encodings and utilities."""
|
||||
|
||||
import abc
|
||||
import functools
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_fourier_features(
|
||||
pos, num_bands, max_resolution=(224, 224),
|
||||
concat_pos=True, sine_only=False):
|
||||
"""Generate a Fourier frequency position encoding with linear spacing.
|
||||
|
||||
Args:
|
||||
pos: The position of n points in d dimensional space.
|
||||
A jnp array of shape [n, d].
|
||||
num_bands: The number of bands (K) to use.
|
||||
max_resolution: The maximum resolution (i.e. the number of pixels per dim).
|
||||
A tuple representing resolution for each dimension
|
||||
concat_pos: Concatenate the input position encoding to the Fourier features?
|
||||
sine_only: Whether to use a single phase (sin) or two (sin/cos) for each
|
||||
frequency band.
|
||||
Returns:
|
||||
embedding: A 1D jnp array of shape [n, n_channels]. If concat_pos is True
|
||||
and sine_only is False, output dimensions are ordered as:
|
||||
[dim_1, dim_2, ..., dim_d,
|
||||
sin(pi*f_1*dim_1), ..., sin(pi*f_K*dim_1), ...,
|
||||
sin(pi*f_1*dim_d), ..., sin(pi*f_K*dim_d),
|
||||
cos(pi*f_1*dim_1), ..., cos(pi*f_K*dim_1), ...,
|
||||
cos(pi*f_1*dim_d), ..., cos(pi*f_K*dim_d)],
|
||||
where dim_i is pos[:, i] and f_k is the kth frequency band.
|
||||
"""
|
||||
min_freq = 1.0
|
||||
# Nyquist frequency at the target resolution:
|
||||
|
||||
freq_bands = jnp.stack([
|
||||
jnp.linspace(min_freq, res / 2, num=num_bands, endpoint=True)
|
||||
for res in max_resolution], axis=0)
|
||||
|
||||
# Get frequency bands for each spatial dimension.
|
||||
# Output is size [n, d * num_bands]
|
||||
per_pos_features = pos[:, :, None] * freq_bands[None, :, :]
|
||||
per_pos_features = jnp.reshape(per_pos_features,
|
||||
[-1, np.prod(per_pos_features.shape[1:])])
|
||||
|
||||
if sine_only:
|
||||
# Output is size [n, d * num_bands]
|
||||
per_pos_features = jnp.sin(jnp.pi * (per_pos_features))
|
||||
else:
|
||||
# Output is size [n, 2 * d * num_bands]
|
||||
per_pos_features = jnp.concatenate(
|
||||
[jnp.sin(jnp.pi * per_pos_features),
|
||||
jnp.cos(jnp.pi * per_pos_features)], axis=-1)
|
||||
# Concatenate the raw input positions.
|
||||
if concat_pos:
|
||||
# Adds d bands to the encoding.
|
||||
per_pos_features = jnp.concatenate([pos, per_pos_features], axis=-1)
|
||||
return per_pos_features
|
||||
|
||||
|
||||
def build_linear_positions(index_dims, output_range=(-1.0, 1.0)):
|
||||
"""Generate an array of position indices for an N-D input array.
|
||||
|
||||
Args:
|
||||
index_dims: The shape of the index dimensions of the input array.
|
||||
output_range: The min and max values taken by each input index dimension.
|
||||
Returns:
|
||||
A jnp array of shape [index_dims[0], index_dims[1], .., index_dims[-1], N].
|
||||
"""
|
||||
def _linspace(n_xels_per_dim):
|
||||
return jnp.linspace(
|
||||
output_range[0], output_range[1],
|
||||
num=n_xels_per_dim,
|
||||
endpoint=True, dtype=jnp.float32)
|
||||
|
||||
dim_ranges = [
|
||||
_linspace(n_xels_per_dim) for n_xels_per_dim in index_dims]
|
||||
array_index_grid = jnp.meshgrid(*dim_ranges, indexing='ij')
|
||||
|
||||
return jnp.stack(array_index_grid, axis=-1)
|
||||
|
||||
|
||||
class AbstractPositionEncoding(hk.Module, metaclass=abc.ABCMeta):
|
||||
"""Abstract Perceiver decoder."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, batch_size, pos):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TrainablePositionEncoding(AbstractPositionEncoding):
|
||||
"""Trainable position encoding."""
|
||||
|
||||
def __init__(self, index_dim, num_channels=128, init_scale=0.02, name=None):
|
||||
super(TrainablePositionEncoding, self).__init__(name=name)
|
||||
self._index_dim = index_dim
|
||||
self._num_channels = num_channels
|
||||
self._init_scale = init_scale
|
||||
|
||||
def __call__(self, batch_size, pos=None):
|
||||
del pos # Unused.
|
||||
pos_embs = hk.get_parameter(
|
||||
'pos_embs', [self._index_dim, self._num_channels],
|
||||
init=hk.initializers.TruncatedNormal(stddev=self._init_scale))
|
||||
|
||||
if batch_size is not None:
|
||||
pos_embs = jnp.broadcast_to(
|
||||
pos_embs[None, :, :], (batch_size,) + pos_embs.shape)
|
||||
return pos_embs
|
||||
|
||||
|
||||
def _check_or_build_spatial_positions(pos, index_dims, batch_size):
|
||||
"""Checks or builds spatial position features (x, y, ...).
|
||||
|
||||
Args:
|
||||
pos: None, or an array of position features. If None, position features
|
||||
are built. Otherwise, their size is checked.
|
||||
index_dims: An iterable giving the spatial/index size of the data to be
|
||||
featurized.
|
||||
batch_size: The batch size of the data to be featurized.
|
||||
Returns:
|
||||
An array of position features, of shape [batch_size, prod(index_dims)].
|
||||
"""
|
||||
if pos is None:
|
||||
pos = build_linear_positions(index_dims)
|
||||
pos = jnp.broadcast_to(pos[None], (batch_size,) + pos.shape)
|
||||
pos = jnp.reshape(pos, [batch_size, np.prod(index_dims), -1])
|
||||
else:
|
||||
# Just a warning label: you probably don't want your spatial features to
|
||||
# have a different spatial layout than your pos coordinate system.
|
||||
# But feel free to override if you think it'll work!
|
||||
assert pos.shape[-1] == len(index_dims)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
class FourierPositionEncoding(AbstractPositionEncoding):
|
||||
"""Fourier (Sinusoidal) position encoding."""
|
||||
|
||||
def __init__(self, index_dims, num_bands, concat_pos=True,
|
||||
max_resolution=None, sine_only=False, name=None):
|
||||
super(FourierPositionEncoding, self).__init__(name=name)
|
||||
self._num_bands = num_bands
|
||||
self._concat_pos = concat_pos
|
||||
self._sine_only = sine_only
|
||||
self._index_dims = index_dims
|
||||
# Use the index dims as the maximum resolution if it's not provided.
|
||||
self._max_resolution = max_resolution or index_dims
|
||||
|
||||
def __call__(self, batch_size, pos=None):
|
||||
pos = _check_or_build_spatial_positions(pos, self._index_dims, batch_size)
|
||||
build_ff_fn = functools.partial(
|
||||
generate_fourier_features,
|
||||
num_bands=self._num_bands,
|
||||
max_resolution=self._max_resolution,
|
||||
concat_pos=self._concat_pos,
|
||||
sine_only=self._sine_only)
|
||||
return jax.vmap(build_ff_fn, 0, 0)(pos)
|
||||
|
||||
|
||||
class PositionEncodingProjector(AbstractPositionEncoding):
|
||||
"""Projects a position encoding to a target size."""
|
||||
|
||||
def __init__(self, output_size, base_position_encoding, name=None):
|
||||
super(PositionEncodingProjector, self).__init__(name=name)
|
||||
self._output_size = output_size
|
||||
self._base_position_encoding = base_position_encoding
|
||||
|
||||
def __call__(self, batch_size, pos=None):
|
||||
base_pos = self._base_position_encoding(batch_size, pos)
|
||||
projected_pos = hk.Linear(output_size=self._output_size)(base_pos)
|
||||
return projected_pos
|
||||
|
||||
|
||||
def build_position_encoding(
|
||||
position_encoding_type,
|
||||
index_dims,
|
||||
project_pos_dim=-1,
|
||||
trainable_position_encoding_kwargs=None,
|
||||
fourier_position_encoding_kwargs=None,
|
||||
name=None):
|
||||
"""Builds the position encoding."""
|
||||
|
||||
if position_encoding_type == 'trainable':
|
||||
assert trainable_position_encoding_kwargs is not None
|
||||
output_pos_enc = TrainablePositionEncoding(
|
||||
# Construct 1D features:
|
||||
index_dim=np.prod(index_dims),
|
||||
name=name,
|
||||
**trainable_position_encoding_kwargs)
|
||||
elif position_encoding_type == 'fourier':
|
||||
assert fourier_position_encoding_kwargs is not None
|
||||
output_pos_enc = FourierPositionEncoding(
|
||||
index_dims=index_dims,
|
||||
name=name,
|
||||
**fourier_position_encoding_kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unknown position encoding: {position_encoding_type}.')
|
||||
|
||||
if project_pos_dim > 0:
|
||||
# Project the position encoding to a target dimension:
|
||||
output_pos_enc = PositionEncodingProjector(
|
||||
output_size=project_pos_dim,
|
||||
base_position_encoding=output_pos_enc)
|
||||
|
||||
return output_pos_enc
|
||||
@@ -0,0 +1,16 @@
|
||||
absl-py==0.13.0
|
||||
dill==0.3.4
|
||||
dm-haiku==0.0.4
|
||||
einops==0.3.0
|
||||
flatbuffers==2.0
|
||||
imageio==2.9.0
|
||||
immutabledict==2.0.0
|
||||
jax==0.2.16
|
||||
jaxlib==0.1.68+cuda111
|
||||
numpy==1.21.0
|
||||
opencv-python==4.5.2.54
|
||||
opt-einsum==3.3.0
|
||||
Pillow==8.3.1
|
||||
scipy==1.7.0
|
||||
six==1.16.0
|
||||
tabulate==0.8.9
|
||||
Reference in New Issue
Block a user