Files
deepmind-research/polygen/modules.py
Florent Altché 8457046b2c Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
2020-08-26 16:54:56 +01:00

1507 lines
55 KiB
Python

# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Modules and networks for mesh generation."""
import sonnet as snt
from tensor2tensor.layers import common_attention
from tensor2tensor.layers import common_layers
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import function
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
def dequantize_verts(verts, n_bits, add_noise=False):
"""Quantizes vertices and outputs integers with specified n_bits."""
min_range = -0.5
max_range = 0.5
range_quantize = 2**n_bits - 1
verts = tf.cast(verts, tf.float32)
verts = verts * (max_range - min_range) / range_quantize + min_range
if add_noise:
verts += tf.random_uniform(tf.shape(verts)) * (1 / float(range_quantize))
return verts
def quantize_verts(verts, n_bits):
"""Dequantizes integer vertices to floats."""
min_range = -0.5
max_range = 0.5
range_quantize = 2**n_bits - 1
verts_quantize = (
(verts - min_range) * range_quantize / (max_range - min_range))
return tf.cast(verts_quantize, tf.int32)
def top_k_logits(logits, k):
"""Masks logits such that logits not in top-k are small."""
if k == 0:
return logits
else:
values, _ = tf.math.top_k(logits, k=k)
k_largest = tf.reduce_min(values)
logits = tf.where(tf.less_equal(logits, k_largest),
tf.ones_like(logits)*-1e9, logits)
return logits
def top_p_logits(logits, p):
"""Masks logits using nucleus (top-p) sampling."""
if p == 1:
return logits
else:
logit_shape = tf.shape(logits)
seq, dim = logit_shape[1], logit_shape[2]
logits = tf.reshape(logits, [-1, dim])
sort_indices = tf.argsort(logits, axis=-1, direction='DESCENDING')
probs = tf.gather(tf.nn.softmax(logits), sort_indices, batch_dims=1)
cumprobs = tf.cumsum(probs, axis=-1, exclusive=True)
# The top 1 candidate always will not be masked.
# This way ensures at least 1 indices will be selected.
sort_mask = tf.cast(tf.greater(cumprobs, p), logits.dtype)
batch_indices = tf.tile(
tf.expand_dims(tf.range(tf.shape(logits)[0]), axis=-1), [1, dim])
top_p_mask = tf.scatter_nd(
tf.stack([batch_indices, sort_indices], axis=-1), sort_mask,
tf.shape(logits))
logits -= top_p_mask * 1e9
return tf.reshape(logits, [-1, seq, dim])
_function_cache = {} # For multihead_self_attention_memory_efficient
def multihead_self_attention_memory_efficient(x,
bias,
num_heads,
head_size=None,
cache=None,
epsilon=1e-6,
forget=True,
test_vars=None,
name=None):
"""Memory-efficient Multihead scaled-dot-product self-attention.
Based on Tensor2Tensor version but adds optional caching.
Returns multihead-self-attention(layer_norm(x))
Computes one attention head at a time to avoid exhausting memory.
If forget=True, then forget all forwards activations and recompute on
the backwards pass.
Args:
x: a Tensor with shape [batch, length, input_size]
bias: an attention bias tensor broadcastable to [batch, 1, length, length]
num_heads: an integer
head_size: an optional integer - defaults to input_size/num_heads
cache: Optional dict containing tensors which are the results of previous
attentions, used for fast decoding. Expects the dict to contain two
keys ('k' and 'v'), for the initial call the values for these keys
should be empty Tensors of the appropriate shape.
'k' [batch_size, 0, key_channels] 'v' [batch_size, 0, value_channels]
epsilon: a float, for layer norm
forget: a boolean - forget forwards activations and recompute on backprop
test_vars: optional tuple of variables for testing purposes
name: an optional string
Returns:
A Tensor.
"""
io_size = x.get_shape().as_list()[-1]
if head_size is None:
assert io_size % num_heads == 0
head_size = io_size / num_heads
def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
"""Forward function."""
n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias)
wqkv_split = tf.unstack(wqkv, num=num_heads)
wo_split = tf.unstack(wo, num=num_heads)
y = 0
if cache is not None:
cache_k = []
cache_v = []
for h in range(num_heads):
with tf.control_dependencies([y] if h > 0 else []):
combined = tf.nn.conv1d(n, wqkv_split[h], 1, 'SAME')
q, k, v = tf.split(combined, 3, axis=2)
if cache is not None:
k = tf.concat([cache['k'][:, h], k], axis=1)
v = tf.concat([cache['v'][:, h], v], axis=1)
cache_k.append(k)
cache_v.append(v)
o = common_attention.scaled_dot_product_attention_simple(
q, k, v, attention_bias)
y += tf.nn.conv1d(o, wo_split[h], 1, 'SAME')
if cache is not None:
cache['k'] = tf.stack(cache_k, axis=1)
cache['v'] = tf.stack(cache_v, axis=1)
return y
key = (
'multihead_self_attention_memory_efficient %s %s' % (num_heads, epsilon))
if not forget:
forward_fn = forward_internal
elif key in _function_cache:
forward_fn = _function_cache[key]
else:
@function.Defun(compiled=True)
def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
"""Custom gradient function."""
with tf.control_dependencies([dy]):
n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias)
wqkv_split = tf.unstack(wqkv, num=num_heads)
wo_split = tf.unstack(wo, num=num_heads)
deps = []
dwqkvs = []
dwos = []
dn = 0
for h in range(num_heads):
with tf.control_dependencies(deps):
combined = tf.nn.conv1d(n, wqkv_split[h], 1, 'SAME')
q, k, v = tf.split(combined, 3, axis=2)
o = common_attention.scaled_dot_product_attention_simple(
q, k, v, attention_bias)
partial_y = tf.nn.conv1d(o, wo_split[h], 1, 'SAME')
pdn, dwqkvh, dwoh = tf.gradients(
ys=[partial_y],
xs=[n, wqkv_split[h], wo_split[h]],
grad_ys=[dy])
dn += pdn
dwqkvs.append(dwqkvh)
dwos.append(dwoh)
deps = [dn, dwqkvh, dwoh]
dwqkv = tf.stack(dwqkvs)
dwo = tf.stack(dwos)
with tf.control_dependencies(deps):
dx, dnorm_scale, dnorm_bias = tf.gradients(
ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
return (dx, dwqkv, dwo, tf.zeros_like(attention_bias), dnorm_scale,
dnorm_bias)
@function.Defun(
grad_func=grad_fn, compiled=True, separate_compiled_gradients=True)
def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
return forward_internal(x, wqkv, wo, attention_bias, norm_scale,
norm_bias)
_function_cache[key] = forward_fn
if bias is not None:
bias = tf.squeeze(bias, 1)
with tf.variable_scope(name, default_name='multihead_attention', values=[x]):
if test_vars is not None:
wqkv, wo, norm_scale, norm_bias = list(test_vars)
else:
wqkv = tf.get_variable(
'wqkv', [num_heads, 1, io_size, 3 * head_size],
initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
wo = tf.get_variable(
'wo', [num_heads, 1, head_size, io_size],
initializer=tf.random_normal_initializer(
stddev=(head_size * num_heads)**-0.5))
norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias)
y.set_shape(x.get_shape()) # pytype: disable=attribute-error
return y
class TransformerEncoder(snt.AbstractModule):
"""Transformer encoder.
Sonnet Transformer encoder module as described in Vaswani et al. 2017. Uses
the Tensor2Tensor multihead_attention function for full self attention
(no masking). Layer norm is applied inside the residual path as in sparse
transformers (Child 2019).
This module expects inputs to be already embedded, and does not add position
embeddings.
"""
def __init__(self,
hidden_size=256,
fc_size=1024,
num_heads=4,
layer_norm=True,
num_layers=8,
dropout_rate=0.2,
re_zero=True,
memory_efficient=False,
name='transformer_encoder'):
"""Initializes TransformerEncoder.
Args:
hidden_size: Size of embedding vectors.
fc_size: Size of fully connected layer.
num_heads: Number of attention heads.
layer_norm: If True, apply layer normalization
num_layers: Number of Transformer blocks, where each block contains a
multi-head attention layer and a MLP.
dropout_rate: Dropout rate applied immediately after the ReLU in each
fully-connected layer.
re_zero: If True, alpha scale residuals with zero init.
memory_efficient: If True, recompute gradients for memory savings.
name: Name of variable scope
"""
super(TransformerEncoder, self).__init__(name=name)
self.hidden_size = hidden_size
self.num_heads = num_heads
self.layer_norm = layer_norm
self.fc_size = fc_size
self.num_layers = num_layers
self.dropout_rate = dropout_rate
self.re_zero = re_zero
self.memory_efficient = memory_efficient
def _build(self, inputs, is_training=False):
"""Passes inputs through Transformer encoder network.
Args:
inputs: Tensor of shape [batch_size, sequence_length, embed_size]. Zero
embeddings are masked in self-attention.
is_training: If True, dropout is applied.
Returns:
output: Tensor of shape [batch_size, sequence_length, embed_size].
"""
if is_training:
dropout_rate = self.dropout_rate
else:
dropout_rate = 0.
# Identify elements with all zeros as padding, and create bias to mask
# out padding elements in self attention.
encoder_padding = common_attention.embedding_to_padding(inputs)
encoder_self_attention_bias = (
common_attention.attention_bias_ignore_padding(encoder_padding))
x = inputs
for layer_num in range(self.num_layers):
with tf.variable_scope('layer_{}'.format(layer_num)):
# Multihead self-attention from Tensor2Tensor.
res = x
if self.memory_efficient:
res = multihead_self_attention_memory_efficient(
res,
bias=encoder_self_attention_bias,
num_heads=self.num_heads,
head_size=self.hidden_size // self.num_heads,
forget=True if is_training else False,
name='self_attention'
)
else:
if self.layer_norm:
res = common_layers.layer_norm(res, name='self_attention')
res = common_attention.multihead_attention(
res,
memory_antecedent=None,
bias=encoder_self_attention_bias,
total_key_depth=self.hidden_size,
total_value_depth=self.hidden_size,
output_depth=self.hidden_size,
num_heads=self.num_heads,
dropout_rate=0.,
make_image_summary=False,
name='self_attention')
if self.re_zero:
res *= tf.get_variable('self_attention/alpha', initializer=0.)
if dropout_rate:
res = tf.nn.dropout(res, rate=dropout_rate)
x += res
# MLP
res = x
if self.layer_norm:
res = common_layers.layer_norm(res, name='fc')
res = tf.layers.dense(
res, self.fc_size, activation=tf.nn.relu, name='fc_1')
res = tf.layers.dense(res, self.hidden_size, name='fc_2')
if self.re_zero:
res *= tf.get_variable('fc/alpha', initializer=0.)
if dropout_rate:
res = tf.nn.dropout(res, rate=dropout_rate)
x += res
if self.layer_norm:
output = common_layers.layer_norm(x, name='output')
else:
output = x
return output
class TransformerDecoder(snt.AbstractModule):
"""Transformer decoder.
Sonnet Transformer decoder module as described in Vaswani et al. 2017. Uses
the Tensor2Tensor multihead_attention function for masked self attention, and
non-masked cross attention attention. Layer norm is applied inside the
residual path as in sparse transformers (Child 2019).
This module expects inputs to be already embedded, and does not
add position embeddings.
"""
def __init__(self,
hidden_size=256,
fc_size=1024,
num_heads=4,
layer_norm=True,
num_layers=8,
dropout_rate=0.2,
re_zero=True,
memory_efficient=False,
name='transformer_decoder'):
"""Initializes TransformerDecoder.
Args:
hidden_size: Size of embedding vectors.
fc_size: Size of fully connected layer.
num_heads: Number of attention heads.
layer_norm: If True, apply layer normalization. If mem_efficient_attention
is True, then layer norm is always applied.
num_layers: Number of Transformer blocks, where each block contains a
multi-head attention layer and a MLP.
dropout_rate: Dropout rate applied immediately after the ReLU in each
fully-connected layer.
re_zero: If True, alpha scale residuals with zero init.
memory_efficient: If True, recompute gradients for memory savings.
name: Name of variable scope
"""
super(TransformerDecoder, self).__init__(name=name)
self.hidden_size = hidden_size
self.num_heads = num_heads
self.layer_norm = layer_norm
self.fc_size = fc_size
self.num_layers = num_layers
self.dropout_rate = dropout_rate
self.re_zero = re_zero
self.memory_efficient = memory_efficient
def _build(self,
inputs,
sequential_context_embeddings=None,
is_training=False,
cache=None):
"""Passes inputs through Transformer decoder network.
Args:
inputs: Tensor of shape [batch_size, sequence_length, embed_size]. Zero
embeddings are masked in self-attention.
sequential_context_embeddings: Optional tensor with global context
(e.g image embeddings) of shape
[batch_size, context_seq_length, context_embed_size].
is_training: If True, dropout is applied.
cache: Optional dict containing tensors which are the results of previous
attentions, used for fast decoding. Expects the dict to contain two
keys ('k' and 'v'), for the initial call the values for these keys
should be empty Tensors of the appropriate shape.
'k' [batch_size, 0, key_channels] 'v' [batch_size, 0, value_channels]
Returns:
output: Tensor of shape [batch_size, sequence_length, embed_size].
"""
if is_training:
dropout_rate = self.dropout_rate
else:
dropout_rate = 0.
# create bias to mask future elements for causal self-attention.
seq_length = tf.shape(inputs)[1]
decoder_self_attention_bias = common_attention.attention_bias_lower_triangle(
seq_length)
# If using sequential_context, identify elements with all zeros as padding,
# and create bias to mask out padding elements in self attention.
if sequential_context_embeddings is not None:
encoder_padding = common_attention.embedding_to_padding(
sequential_context_embeddings)
encoder_decoder_attention_bias = (
common_attention.attention_bias_ignore_padding(encoder_padding))
x = inputs
for layer_num in range(self.num_layers):
with tf.variable_scope('layer_{}'.format(layer_num)):
# If using cached decoding, access cache for current layer, and create
# bias that enables un-masked attention into the cache
if cache is not None:
layer_cache = cache[layer_num]
layer_decoder_bias = tf.zeros([1, 1, 1, 1])
# Otherwise use standard masked bias
else:
layer_cache = None
layer_decoder_bias = decoder_self_attention_bias
# Multihead self-attention from Tensor2Tensor.
res = x
if self.memory_efficient:
res = multihead_self_attention_memory_efficient(
res,
bias=layer_decoder_bias,
cache=layer_cache,
num_heads=self.num_heads,
head_size=self.hidden_size // self.num_heads,
forget=True if is_training else False,
name='self_attention'
)
else:
if self.layer_norm:
res = common_layers.layer_norm(res, name='self_attention')
res = common_attention.multihead_attention(
res,
memory_antecedent=None,
bias=layer_decoder_bias,
total_key_depth=self.hidden_size,
total_value_depth=self.hidden_size,
output_depth=self.hidden_size,
num_heads=self.num_heads,
cache=layer_cache,
dropout_rate=0.,
make_image_summary=False,
name='self_attention')
if self.re_zero:
res *= tf.get_variable('self_attention/alpha', initializer=0.)
if dropout_rate:
res = tf.nn.dropout(res, rate=dropout_rate)
x += res
# Optional cross attention into sequential context
if sequential_context_embeddings is not None:
res = x
if self.layer_norm:
res = common_layers.layer_norm(res, name='cross_attention')
res = common_attention.multihead_attention(
res,
memory_antecedent=sequential_context_embeddings,
bias=encoder_decoder_attention_bias,
total_key_depth=self.hidden_size,
total_value_depth=self.hidden_size,
output_depth=self.hidden_size,
num_heads=self.num_heads,
dropout_rate=0.,
make_image_summary=False,
name='cross_attention')
if self.re_zero:
res *= tf.get_variable('cross_attention/alpha', initializer=0.)
if dropout_rate:
res = tf.nn.dropout(res, rate=dropout_rate)
x += res
# FC layers
res = x
if self.layer_norm:
res = common_layers.layer_norm(res, name='fc')
res = tf.layers.dense(
res, self.fc_size, activation=tf.nn.relu, name='fc_1')
res = tf.layers.dense(res, self.hidden_size, name='fc_2')
if self.re_zero:
res *= tf.get_variable('fc/alpha', initializer=0.)
if dropout_rate:
res = tf.nn.dropout(res, rate=dropout_rate)
x += res
if self.layer_norm:
output = common_layers.layer_norm(x, name='output')
else:
output = x
return output
def create_init_cache(self, batch_size):
"""Creates empty cache dictionary for use in fast decoding."""
def compute_cache_shape_invariants(tensor):
"""Helper function to get dynamic shapes for cache tensors."""
shape_list = tensor.shape.as_list()
if len(shape_list) == 4:
return tf.TensorShape(
[shape_list[0], shape_list[1], None, shape_list[3]])
elif len(shape_list) == 3:
return tf.TensorShape([shape_list[0], None, shape_list[2]])
# Build cache
k = common_attention.split_heads(
tf.zeros([batch_size, 0, self.hidden_size]), self.num_heads)
v = common_attention.split_heads(
tf.zeros([batch_size, 0, self.hidden_size]), self.num_heads)
cache = [{'k': k, 'v': v} for _ in range(self.num_layers)]
shape_invariants = tf.nest.map_structure(
compute_cache_shape_invariants, cache)
return cache, shape_invariants
def conv_residual_block(inputs,
output_channels=None,
downsample=False,
kernel_size=3,
re_zero=True,
dropout_rate=0.,
name='conv_residual_block'):
"""Convolutional block with residual connections for 2D or 3D inputs.
Args:
inputs: Input tensor of shape [batch_size, height, width, channels] or
[batch_size, height, width, depth, channels].
output_channels: Number of output channels.
downsample: If True, downsample by 1/2 in this block.
kernel_size: Spatial size of convolutional kernels.
re_zero: If True, alpha scale residuals with zero init.
dropout_rate: Dropout rate applied after second ReLU in residual path.
name: Name for variable scope.
Returns:
outputs: Output tensor of shape [batch_size, height, width, output_channels]
or [batch_size, height, width, depth, output_channels].
"""
with tf.variable_scope(name):
input_shape = inputs.get_shape().as_list()
num_dims = len(input_shape) - 2
if num_dims == 2:
conv = tf.layers.conv2d
elif num_dims == 3:
conv = tf.layers.conv3d
input_channels = input_shape[-1]
if output_channels is None:
output_channels = input_channels
if downsample:
shortcut = conv(
inputs,
filters=output_channels,
strides=2,
kernel_size=kernel_size,
padding='same',
name='conv_shortcut')
else:
shortcut = inputs
res = inputs
res = tf.nn.relu(res)
res = conv(
res, filters=input_channels, kernel_size=kernel_size, padding='same',
name='conv_1')
res = tf.nn.relu(res)
if dropout_rate:
res = tf.nn.dropout(res, rate=dropout_rate)
if downsample:
out_strides = 2
else:
out_strides = 1
res = conv(
res,
filters=output_channels,
kernel_size=kernel_size,
padding='same',
strides=out_strides,
name='conv_2')
if re_zero:
res *= tf.get_variable('alpha', initializer=0.)
return shortcut + res
class ResNet(snt.AbstractModule):
"""ResNet architecture for 2D image or 3D voxel inputs."""
def __init__(self,
num_dims,
hidden_sizes=(64, 256),
num_blocks=(2, 2),
dropout_rate=0.1,
re_zero=True,
name='res_net'):
"""Initializes ResNet.
Args:
num_dims: Number of spatial dimensions. 2 for images or 3 for voxels.
hidden_sizes: Sizes of hidden layers in resnet blocks.
num_blocks: Number of resnet blocks at each size.
dropout_rate: Dropout rate applied immediately after the ReLU in each
fully-connected layer.
re_zero: If True, alpha scale residuals with zero init.
name: Name of variable scope
"""
super(ResNet, self).__init__(name=name)
self.num_dims = num_dims
self.hidden_sizes = hidden_sizes
self.num_blocks = num_blocks
self.dropout_rate = dropout_rate
self.re_zero = re_zero
def _build(self, inputs, is_training=False):
"""Passes inputs through resnet.
Args:
inputs: Tensor of shape [batch_size, height, width, channels] or
[batch_size, height, width, depth, channels].
is_training: If True, dropout is applied.
Returns:
output: Tensor of shape [batch_size, height, width, depth, output_size].
"""
if is_training:
dropout_rate = self.dropout_rate
else:
dropout_rate = 0.
# Initial projection with large kernel as in original resnet architecture
if self.num_dims == 3:
conv = tf.layers.conv3d
elif self.num_dims == 2:
conv = tf.layers.conv2d
x = conv(
inputs,
filters=self.hidden_sizes[0],
kernel_size=7,
strides=2,
padding='same',
name='conv_input')
if self.num_dims == 2:
x = tf.layers.max_pooling2d(
x, strides=2, pool_size=3, padding='same', name='pool_input')
for d, (hidden_size,
blocks) in enumerate(zip(self.hidden_sizes, self.num_blocks)):
with tf.variable_scope('resolution_{}'.format(d)):
# Downsample at the start of each collection of blocks
x = conv_residual_block(
x,
downsample=False if d == 0 else True,
dropout_rate=dropout_rate,
output_channels=hidden_size,
re_zero=self.re_zero,
name='block_1_downsample')
for i in range(blocks - 1):
x = conv_residual_block(
x,
dropout_rate=dropout_rate,
output_channels=hidden_size,
re_zero=self.re_zero,
name='block_{}'.format(i + 2))
return x
class VertexModel(snt.AbstractModule):
"""Autoregressive generative model of quantized mesh vertices.
Operates on flattened vertex sequences with a stopping token:
[z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, z_n, STOP]
Input vertex coordinates are embedded and tagged with learned coordinate and
position indicators. A transformer decoder outputs logits for a quantized
vertex distribution.
"""
def __init__(self,
decoder_config,
quantization_bits,
class_conditional=False,
num_classes=55,
max_num_input_verts=2500,
use_discrete_embeddings=True,
name='vertex_model'):
"""Initializes VertexModel.
Args:
decoder_config: Dictionary with TransformerDecoder config
quantization_bits: Number of quantization used in mesh preprocessing.
class_conditional: If True, then condition on learned class embeddings.
num_classes: Number of classes to condition on.
max_num_input_verts: Maximum number of vertices. Used for learned position
embeddings.
use_discrete_embeddings: If True, use discrete rather than continuous
vertex embeddings.
name: Name of variable scope
"""
super(VertexModel, self).__init__(name=name)
self.embedding_dim = decoder_config['hidden_size']
self.class_conditional = class_conditional
self.num_classes = num_classes
self.max_num_input_verts = max_num_input_verts
self.quantization_bits = quantization_bits
self.use_discrete_embeddings = use_discrete_embeddings
with self._enter_variable_scope():
self.decoder = TransformerDecoder(**decoder_config)
@snt.reuse_variables
def _embed_class_label(self, labels):
"""Embeds class label with learned embedding matrix."""
init_dict = {'embeddings': tf.glorot_uniform_initializer}
return snt.Embed(
vocab_size=self.num_classes,
embed_dim=self.embedding_dim,
initializers=init_dict,
densify_gradients=True,
name='class_label')(labels)
@snt.reuse_variables
def _prepare_context(self, context, is_training=False):
"""Prepare class label context."""
if self.class_conditional:
global_context_embedding = self._embed_class_label(context['class_label'])
else:
global_context_embedding = None
return global_context_embedding, None
@snt.reuse_variables
def _embed_inputs(self, vertices, global_context_embedding=None):
"""Embeds flat vertices and adds position and coordinate information."""
# Dequantize inputs and get shapes
input_shape = tf.shape(vertices)
batch_size, seq_length = input_shape[0], input_shape[1]
# Coord indicators (x, y, z)
coord_embeddings = snt.Embed(
vocab_size=3,
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='coord_embeddings')(tf.mod(tf.range(seq_length), 3))
# Position embeddings
pos_embeddings = snt.Embed(
vocab_size=self.max_num_input_verts,
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='coord_embeddings')(tf.floordiv(tf.range(seq_length), 3))
# Discrete vertex value embeddings
if self.use_discrete_embeddings:
vert_embeddings = snt.Embed(
vocab_size=2**self.quantization_bits + 1,
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='value_embeddings')(vertices)
# Continuous vertex value embeddings
else:
vert_embeddings = tf.layers.dense(
dequantize_verts(vertices[..., None], self.quantization_bits),
self.embedding_dim,
use_bias=True,
name='value_embeddings')
# Step zero embeddings
if global_context_embedding is None:
zero_embed = tf.get_variable(
'embed_zero', shape=[1, 1, self.embedding_dim])
zero_embed_tiled = tf.tile(zero_embed, [batch_size, 1, 1])
else:
zero_embed_tiled = global_context_embedding[:, None]
# Aggregate embeddings
embeddings = vert_embeddings + (coord_embeddings + pos_embeddings)[None]
embeddings = tf.concat([zero_embed_tiled, embeddings], axis=1)
return embeddings
@snt.reuse_variables
def _project_to_logits(self, inputs):
"""Projects transformer outputs to logits for predictive distribution."""
return tf.layers.dense(
inputs,
2**self.quantization_bits + 1, # + 1 for stopping token
use_bias=True,
kernel_initializer=tf.zeros_initializer(),
name='project_to_logits')
@snt.reuse_variables
def _create_dist(self,
vertices,
global_context_embedding=None,
sequential_context_embeddings=None,
temperature=1.,
top_k=0,
top_p=1.,
is_training=False,
cache=None):
"""Outputs categorical dist for quantized vertex coordinates."""
# Embed inputs
decoder_inputs = self._embed_inputs(vertices, global_context_embedding)
if cache is not None:
decoder_inputs = decoder_inputs[:, -1:]
# pass through decoder
outputs = self.decoder(
decoder_inputs, cache=cache,
sequential_context_embeddings=sequential_context_embeddings,
is_training=is_training)
# Get logits and optionally process for sampling
logits = self._project_to_logits(outputs)
logits /= temperature
logits = top_k_logits(logits, top_k)
logits = top_p_logits(logits, top_p)
cat_dist = tfd.Categorical(logits=logits)
return cat_dist
def _build(self, batch, is_training=False):
"""Pass batch through vertex model and get log probabilities under model.
Args:
batch: Dictionary containing:
'vertices_flat': int32 vertex tensors of shape [batch_size, seq_length].
is_training: If True, use dropout.
Returns:
pred_dist: tfd.Categorical predictive distribution with batch shape
[batch_size, seq_length].
"""
global_context, seq_context = self._prepare_context(
batch, is_training=is_training)
pred_dist = self._create_dist(
batch['vertices_flat'][:, :-1], # Last element not used for preds
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
is_training=is_training)
return pred_dist
def sample(self,
num_samples,
context=None,
max_sample_length=None,
temperature=1.,
top_k=0,
top_p=1.,
recenter_verts=True,
only_return_complete=True):
"""Autoregressive sampling with caching.
Args:
num_samples: Number of samples to produce.
context: Dictionary of context, such as class labels. See _prepare_context
for details.
max_sample_length: Maximum length of sampled vertex sequences. Sequences
that do not complete are truncated.
temperature: Scalar softmax temperature > 0.
top_k: Number of tokens to keep for top-k sampling.
top_p: Proportion of probability mass to keep for top-p sampling.
recenter_verts: If True, center vertex samples around origin. This should
be used if model is trained using shift augmentations.
only_return_complete: If True, only return completed samples. Otherwise
return all samples along with completed indicator.
Returns:
outputs: Output dictionary with fields:
'completed': Boolean tensor of shape [num_samples]. If True then
corresponding sample completed within max_sample_length.
'vertices': Tensor of samples with shape [num_samples, num_verts, 3].
'num_vertices': Tensor indicating number of vertices for each example
in padded vertex samples.
'vertices_mask': Tensor of shape [num_samples, num_verts] that masks
corresponding invalid elements in 'vertices'.
"""
# Obtain context for decoder
global_context, seq_context = self._prepare_context(
context, is_training=False)
# num_samples is the minimum value of num_samples and the batch size of
# context inputs (if present).
if global_context is not None:
num_samples = tf.minimum(num_samples, tf.shape(global_context)[0])
global_context = global_context[:num_samples]
if seq_context is not None:
seq_context = seq_context[:num_samples]
elif seq_context is not None:
num_samples = tf.minimum(num_samples, tf.shape(seq_context)[0])
seq_context = seq_context[:num_samples]
def _loop_body(i, samples, cache):
"""While-loop body for autoregression calculation."""
cat_dist = self._create_dist(
samples,
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
cache=cache,
temperature=temperature,
top_k=top_k,
top_p=top_p)
next_sample = cat_dist.sample()
samples = tf.concat([samples, next_sample], axis=1)
return i + 1, samples, cache
def _stopping_cond(i, samples, cache):
"""Stopping condition for sampling while-loop."""
del i, cache # Unused
return tf.reduce_any(tf.reduce_all(tf.not_equal(samples, 0), axis=-1))
# Initial values for loop variables
samples = tf.zeros([num_samples, 0], dtype=tf.int32)
max_sample_length = max_sample_length or self.max_num_input_verts
cache, cache_shape_invariants = self.decoder.create_init_cache(num_samples)
_, v, _ = tf.while_loop(
cond=_stopping_cond,
body=_loop_body,
loop_vars=(0, samples, cache),
shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]),
cache_shape_invariants),
maximum_iterations=max_sample_length * 3 + 1,
back_prop=False,
parallel_iterations=1)
# Check if samples completed. Samples are complete if the stopping token
# is produced.
completed = tf.reduce_any(tf.equal(v, 0), axis=-1)
# Get the number of vertices in the sample. This requires finding the
# index of the stopping token. For complete samples use to argmax to get
# first nonzero index.
stop_index_completed = tf.argmax(
tf.cast(tf.equal(v, 0), tf.int32), axis=-1, output_type=tf.int32)
# For incomplete samples the stopping index is just the maximum index.
stop_index_incomplete = (
max_sample_length * 3 * tf.ones_like(stop_index_completed))
stop_index = tf.where(
completed, stop_index_completed, stop_index_incomplete)
num_vertices = tf.floordiv(stop_index, 3)
# Convert to 3D vertices by reshaping and re-ordering x -> y -> z
v = v[:, :(tf.reduce_max(num_vertices) * 3)] - 1
verts_dequantized = dequantize_verts(v, self.quantization_bits)
vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
vertices = tf.stack(
[vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1)
# Pad samples to max sample length. This is required in order to concatenate
# Samples across different replicator instances. Pad with stopping tokens
# for incomplete samples.
pad_size = max_sample_length - tf.shape(vertices)[1]
vertices = tf.pad(vertices, [[0, 0], [0, pad_size], [0, 0]])
# 3D Vertex mask
vertices_mask = tf.cast(
tf.range(max_sample_length)[None] < num_vertices[:, None], tf.float32)
if recenter_verts:
vert_max = tf.reduce_max(
vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1,
keepdims=True)
vert_min = tf.reduce_min(
vertices + 1e10 * (1. - vertices_mask)[..., None], axis=1,
keepdims=True)
vert_centers = 0.5 * (vert_max + vert_min)
vertices -= vert_centers
vertices *= vertices_mask[..., None]
if only_return_complete:
vertices = tf.boolean_mask(vertices, completed)
num_vertices = tf.boolean_mask(num_vertices, completed)
vertices_mask = tf.boolean_mask(vertices_mask, completed)
completed = tf.boolean_mask(completed, completed)
# Outputs
outputs = {
'completed': completed,
'vertices': vertices,
'num_vertices': num_vertices,
'vertices_mask': vertices_mask,
}
return outputs
class ImageToVertexModel(VertexModel):
"""Generative model of quantized mesh vertices with image conditioning.
Operates on flattened vertex sequences with a stopping token:
[z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, z_n, STOP]
Input vertex coordinates are embedded and tagged with learned coordinate and
position indicators. A transformer decoder outputs logits for a quantized
vertex distribution. Image inputs are encoded and used to condition the
vertex decoder.
"""
def __init__(self,
res_net_config,
decoder_config,
quantization_bits,
use_discrete_embeddings=True,
max_num_input_verts=2500,
name='image_to_vertex_model'):
"""Initializes VoxelToVertexModel.
Args:
res_net_config: Dictionary with ResNet config.
decoder_config: Dictionary with TransformerDecoder config.
quantization_bits: Number of quantization used in mesh preprocessing.
use_discrete_embeddings: If True, use discrete rather than continuous
vertex embeddings.
max_num_input_verts: Maximum number of vertices. Used for learned position
embeddings.
name: Name of variable scope
"""
super(ImageToVertexModel, self).__init__(
decoder_config=decoder_config,
quantization_bits=quantization_bits,
max_num_input_verts=max_num_input_verts,
use_discrete_embeddings=use_discrete_embeddings,
name=name)
with self._enter_variable_scope():
self.res_net = ResNet(num_dims=2, **res_net_config)
@snt.reuse_variables
def _prepare_context(self, context, is_training=False):
# Pass images through encoder
image_embeddings = self.res_net(
context['image'] - 0.5, is_training=is_training)
# Add 2D coordinate grid embedding
processed_image_resolution = tf.shape(image_embeddings)[1]
x = tf.linspace(-1., 1., processed_image_resolution)
image_coords = tf.stack(tf.meshgrid(x, x), axis=-1)
image_coord_embeddings = tf.layers.dense(
image_coords,
self.embedding_dim,
use_bias=True,
name='image_coord_embeddings')
image_embeddings += image_coord_embeddings[None]
# Reshape spatial grid to sequence
batch_size = tf.shape(image_embeddings)[0]
sequential_context_embedding = tf.reshape(
image_embeddings, [batch_size, -1, self.embedding_dim])
return None, sequential_context_embedding
class VoxelToVertexModel(VertexModel):
"""Generative model of quantized mesh vertices with voxel conditioning.
Operates on flattened vertex sequences with a stopping token:
[z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, z_n, STOP]
Input vertex coordinates are embedded and tagged with learned coordinate and
position indicators. A transformer decoder outputs logits for a quantized
vertex distribution. Image inputs are encoded and used to condition the
vertex decoder.
"""
def __init__(self,
res_net_config,
decoder_config,
quantization_bits,
use_discrete_embeddings=True,
max_num_input_verts=2500,
name='voxel_to_vertex_model'):
"""Initializes VoxelToVertexModel.
Args:
res_net_config: Dictionary with ResNet config.
decoder_config: Dictionary with TransformerDecoder config.
quantization_bits: Integer number of bits used for vertex quantization.
use_discrete_embeddings: If True, use discrete rather than continuous
vertex embeddings.
max_num_input_verts: Maximum number of vertices. Used for learned position
embeddings.
name: Name of variable scope
"""
super(VoxelToVertexModel, self).__init__(
decoder_config=decoder_config,
quantization_bits=quantization_bits,
max_num_input_verts=max_num_input_verts,
use_discrete_embeddings=use_discrete_embeddings,
name=name)
with self._enter_variable_scope():
self.res_net = ResNet(num_dims=3, **res_net_config)
@snt.reuse_variables
def _prepare_context(self, context, is_training=False):
# Embed binary input voxels
voxel_embeddings = snt.Embed(
vocab_size=2,
embed_dim=self.pre_embed_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='voxel_embeddings')(context['voxels'])
# Pass embedded voxels through voxel encoder
voxel_embeddings = self.res_net(
voxel_embeddings, is_training=is_training)
# Add 3D coordinate grid embedding
processed_voxel_resolution = tf.shape(voxel_embeddings)[1]
x = tf.linspace(-1., 1., processed_voxel_resolution)
voxel_coords = tf.stack(tf.meshgrid(x, x, x), axis=-1)
voxel_coord_embeddings = tf.layers.dense(
voxel_coords,
self.embedding_dim,
use_bias=True,
name='voxel_coord_embeddings')
voxel_embeddings += voxel_coord_embeddings[None]
# Reshape spatial grid to sequence
batch_size = tf.shape(voxel_embeddings)[0]
sequential_context_embedding = tf.reshape(
voxel_embeddings, [batch_size, -1, self.embedding_dim])
return None, sequential_context_embedding
class FaceModel(snt.AbstractModule):
"""Autoregressive generative model of n-gon meshes.
Operates on sets of input vertices as well as flattened face sequences with
new face and stopping tokens:
[f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
Input vertices are encoded using a Transformer encoder.
Input face sequences are embedded and tagged with learned position indicators,
as well as their corresponding vertex embeddings. A transformer decoder
outputs a pointer which is compared to each vertex embedding to obtain a
distribution over vertex indices.
"""
def __init__(self,
encoder_config,
decoder_config,
class_conditional=True,
num_classes=55,
decoder_cross_attention=True,
use_discrete_vertex_embeddings=True,
quantization_bits=8,
max_seq_length=5000,
name='face_model'):
"""Initializes FaceModel.
Args:
encoder_config: Dictionary with TransformerEncoder config.
decoder_config: Dictionary with TransformerDecoder config.
class_conditional: If True, then condition on learned class embeddings.
num_classes: Number of classes to condition on.
decoder_cross_attention: If True, the use cross attention from decoder
querys into encoder outputs.
use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
quantization_bits: Number of quantization bits for discrete vertex
embeddings.
max_seq_length: Maximum face sequence length. Used for learned position
embeddings.
name: Name of variable scope
"""
super(FaceModel, self).__init__(name=name)
self.embedding_dim = decoder_config['hidden_size']
self.class_conditional = class_conditional
self.num_classes = num_classes
self.max_seq_length = max_seq_length
self.decoder_cross_attention = decoder_cross_attention
self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
self.quantization_bits = quantization_bits
with self._enter_variable_scope():
self.decoder = TransformerDecoder(**decoder_config)
self.encoder = TransformerEncoder(**encoder_config)
@snt.reuse_variables
def _embed_class_label(self, labels):
"""Embeds class label with learned embedding matrix."""
init_dict = {'embeddings': tf.glorot_uniform_initializer}
return snt.Embed(
vocab_size=self.num_classes,
embed_dim=self.embedding_dim,
initializers=init_dict,
densify_gradients=True,
name='class_label')(labels)
@snt.reuse_variables
def _prepare_context(self, context, is_training=False):
"""Prepare class label and vertex context."""
if self.class_conditional:
global_context_embedding = self._embed_class_label(context['class_label'])
else:
global_context_embedding = None
vertex_embeddings = self._embed_vertices(
context['vertices'], context['vertices_mask'],
is_training=is_training)
if self.decoder_cross_attention:
sequential_context_embeddings = (
vertex_embeddings *
tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
constant_values=1)[..., None])
else:
sequential_context_embeddings = None
return (vertex_embeddings, global_context_embedding,
sequential_context_embeddings)
@snt.reuse_variables
def _embed_vertices(self, vertices, vertices_mask, is_training=False):
"""Embeds vertices with transformer encoder."""
# num_verts = tf.shape(vertices)[1]
if self.use_discrete_vertex_embeddings:
vertex_embeddings = 0.
verts_quantized = quantize_verts(vertices, self.quantization_bits)
for c in range(3):
vertex_embeddings += snt.Embed(
vocab_size=256,
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='coord_{}'.format(c))(verts_quantized[..., c])
else:
vertex_embeddings = tf.layers.dense(
vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
vertex_embeddings *= vertices_mask[..., None]
# Pad vertex embeddings with learned embeddings for stopping and new face
# tokens
stopping_embeddings = tf.get_variable(
'stopping_embeddings', shape=[1, 2, self.embedding_dim])
stopping_embeddings = tf.tile(stopping_embeddings,
[tf.shape(vertices)[0], 1, 1])
vertex_embeddings = tf.concat(
[stopping_embeddings, vertex_embeddings], axis=1)
# Pass through Transformer encoder
vertex_embeddings = self.encoder(vertex_embeddings, is_training=is_training)
return vertex_embeddings
@snt.reuse_variables
def _embed_inputs(self, faces_long, vertex_embeddings,
global_context_embedding=None):
"""Embeds face sequences and adds within and between face positions."""
# Face value embeddings are gathered vertex embeddings
face_embeddings = tf.gather(vertex_embeddings, faces_long, batch_dims=1)
# Position embeddings
pos_embeddings = snt.Embed(
vocab_size=self.max_seq_length,
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='coord_embeddings')(tf.range(tf.shape(faces_long)[1]))
# Step zero embeddings
batch_size = tf.shape(face_embeddings)[0]
if global_context_embedding is None:
zero_embed = tf.get_variable(
'embed_zero', shape=[1, 1, self.embedding_dim])
zero_embed_tiled = tf.tile(zero_embed, [batch_size, 1, 1])
else:
zero_embed_tiled = global_context_embedding[:, None]
# Aggregate embeddings
embeddings = face_embeddings + pos_embeddings[None]
embeddings = tf.concat([zero_embed_tiled, embeddings], axis=1)
return embeddings
@snt.reuse_variables
def _project_to_pointers(self, inputs):
"""Projects transformer outputs to pointer vectors."""
return tf.layers.dense(
inputs,
self.embedding_dim,
use_bias=True,
kernel_initializer=tf.zeros_initializer(),
name='project_to_pointers'
)
@snt.reuse_variables
def _create_dist(self,
vertex_embeddings,
vertices_mask,
faces_long,
global_context_embedding=None,
sequential_context_embeddings=None,
temperature=1.,
top_k=0,
top_p=1.,
is_training=False,
cache=None):
"""Outputs categorical dist for vertex indices."""
# Embed inputs
decoder_inputs = self._embed_inputs(
faces_long, vertex_embeddings, global_context_embedding)
# Pass through Transformer decoder
if cache is not None:
decoder_inputs = decoder_inputs[:, -1:]
decoder_outputs = self.decoder(
decoder_inputs,
cache=cache,
sequential_context_embeddings=sequential_context_embeddings,
is_training=is_training)
# Get pointers
pred_pointers = self._project_to_pointers(decoder_outputs)
# Get logits and mask
logits = tf.matmul(pred_pointers, vertex_embeddings, transpose_b=True)
logits /= tf.sqrt(float(self.embedding_dim))
f_verts_mask = tf.pad(
vertices_mask, [[0, 0], [2, 0]], constant_values=1.)[:, None]
logits *= f_verts_mask
logits -= (1. - f_verts_mask) * 1e9
logits /= temperature
logits = top_k_logits(logits, top_k)
logits = top_p_logits(logits, top_p)
return tfd.Categorical(logits=logits)
def _build(self, batch, is_training=False):
"""Pass batch through face model and get log probabilities.
Args:
batch: Dictionary containing:
'vertices_dequantized': Tensor of shape [batch_size, num_vertices, 3].
'faces': int32 tensor of shape [batch_size, seq_length] with flattened
faces.
'vertices_mask': float32 tensor with shape
[batch_size, num_vertices] that masks padded elements in 'vertices'.
is_training: If True, use dropout.
Returns:
pred_dist: tfd.Categorical predictive distribution with batch shape
[batch_size, seq_length].
"""
vertex_embeddings, global_context, seq_context = self._prepare_context(
batch, is_training=is_training)
pred_dist = self._create_dist(
vertex_embeddings,
batch['vertices_mask'],
batch['faces'][:, :-1],
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
is_training=is_training)
return pred_dist
def sample(self,
context,
max_sample_length=None,
temperature=1.,
top_k=0,
top_p=1.,
only_return_complete=True):
"""Sample from face model using caching.
Args:
context: Dictionary of context, including 'vertices' and 'vertices_mask'.
See _prepare_context for details.
max_sample_length: Maximum length of sampled vertex sequences. Sequences
that do not complete are truncated.
temperature: Scalar softmax temperature > 0.
top_k: Number of tokens to keep for top-k sampling.
top_p: Proportion of probability mass to keep for top-p sampling.
only_return_complete: If True, only return completed samples. Otherwise
return all samples along with completed indicator.
Returns:
outputs: Output dictionary with fields:
'completed': Boolean tensor of shape [num_samples]. If True then
corresponding sample completed within max_sample_length.
'faces': Tensor of samples with shape [num_samples, num_verts, 3].
'num_face_indices': Tensor indicating number of vertices for each
example in padded vertex samples.
"""
vertex_embeddings, global_context, seq_context = self._prepare_context(
context, is_training=False)
num_samples = tf.shape(vertex_embeddings)[0]
def _loop_body(i, samples, cache):
"""While-loop body for autoregression calculation."""
pred_dist = self._create_dist(
vertex_embeddings,
context['vertices_mask'],
samples,
global_context_embedding=global_context,
sequential_context_embeddings=seq_context,
cache=cache,
temperature=temperature,
top_k=top_k,
top_p=top_p)
next_sample = pred_dist.sample()[:, -1:]
samples = tf.concat([samples, next_sample], axis=1)
return i + 1, samples, cache
def _stopping_cond(i, samples, cache):
"""Stopping conditions for autoregressive calculation."""
del i, cache # Unused
return tf.reduce_any(tf.reduce_all(tf.not_equal(samples, 0), axis=-1))
# While loop sampling with caching
samples = tf.zeros([num_samples, 0], dtype=tf.int32)
max_sample_length = max_sample_length or self.max_seq_length
cache, cache_shape_invariants = self.decoder.create_init_cache(num_samples)
_, f, _ = tf.while_loop(
cond=_stopping_cond,
body=_loop_body,
loop_vars=(0, samples, cache),
shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]),
cache_shape_invariants),
back_prop=False,
parallel_iterations=1,
maximum_iterations=max_sample_length)
# Record completed samples
complete_samples = tf.reduce_any(tf.equal(f, 0), axis=-1)
# Find number of faces
sample_length = tf.shape(f)[-1]
# Get largest new face (1) index as stopping point for incomplete samples.
max_one_ind = tf.reduce_max(
tf.range(sample_length)[None] * tf.cast(tf.equal(f, 1), tf.int32),
axis=-1)
zero_inds = tf.cast(
tf.argmax(tf.cast(tf.equal(f, 0), tf.int32), axis=-1), tf.int32)
num_face_indices = tf.where(complete_samples, zero_inds, max_one_ind) + 1
# Mask faces beyond stopping token with zeros
# This mask has a -1 in order to replace the last new face token with zero
faces_mask = tf.cast(
tf.range(sample_length)[None] < num_face_indices[:, None] - 1, tf.int32)
f *= faces_mask
# This is the real mask
faces_mask = tf.cast(
tf.range(sample_length)[None] < num_face_indices[:, None], tf.int32)
# Pad to maximum size with zeros
pad_size = max_sample_length - sample_length
f = tf.pad(f, [[0, 0], [0, pad_size]])
if only_return_complete:
f = tf.boolean_mask(f, complete_samples)
num_face_indices = tf.boolean_mask(num_face_indices, complete_samples)
context = tf.nest.map_structure(
lambda x: tf.boolean_mask(x, complete_samples), context)
complete_samples = tf.boolean_mask(complete_samples, complete_samples)
# outputs
outputs = {
'context': context,
'completed': complete_samples,
'faces': f,
'num_face_indices': num_face_indices,
}
return outputs