# Copyright 2020 Deepmind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Modules and networks for mesh generation.""" import sonnet as snt from tensor2tensor.layers import common_attention from tensor2tensor.layers import common_layers import tensorflow.compat.v1 as tf from tensorflow.python.framework import function import tensorflow_probability as tfp tfd = tfp.distributions tfb = tfp.bijectors def dequantize_verts(verts, n_bits, add_noise=False): """Quantizes vertices and outputs integers with specified n_bits.""" min_range = -0.5 max_range = 0.5 range_quantize = 2**n_bits - 1 verts = tf.cast(verts, tf.float32) verts = verts * (max_range - min_range) / range_quantize + min_range if add_noise: verts += tf.random_uniform(tf.shape(verts)) * (1 / float(range_quantize)) return verts def quantize_verts(verts, n_bits): """Dequantizes integer vertices to floats.""" min_range = -0.5 max_range = 0.5 range_quantize = 2**n_bits - 1 verts_quantize = ( (verts - min_range) * range_quantize / (max_range - min_range)) return tf.cast(verts_quantize, tf.int32) def top_k_logits(logits, k): """Masks logits such that logits not in top-k are small.""" if k == 0: return logits else: values, _ = tf.math.top_k(logits, k=k) k_largest = tf.reduce_min(values) logits = tf.where(tf.less_equal(logits, k_largest), tf.ones_like(logits)*-1e9, logits) return logits def top_p_logits(logits, p): """Masks logits using nucleus (top-p) sampling.""" if p == 1: return logits else: logit_shape = 
tf.shape(logits) seq, dim = logit_shape[1], logit_shape[2] logits = tf.reshape(logits, [-1, dim]) sort_indices = tf.argsort(logits, axis=-1, direction='DESCENDING') probs = tf.gather(tf.nn.softmax(logits), sort_indices, batch_dims=1) cumprobs = tf.cumsum(probs, axis=-1, exclusive=True) # The top 1 candidate always will not be masked. # This way ensures at least 1 indices will be selected. sort_mask = tf.cast(tf.greater(cumprobs, p), logits.dtype) batch_indices = tf.tile( tf.expand_dims(tf.range(tf.shape(logits)[0]), axis=-1), [1, dim]) top_p_mask = tf.scatter_nd( tf.stack([batch_indices, sort_indices], axis=-1), sort_mask, tf.shape(logits)) logits -= top_p_mask * 1e9 return tf.reshape(logits, [-1, seq, dim]) _function_cache = {} # For multihead_self_attention_memory_efficient def multihead_self_attention_memory_efficient(x, bias, num_heads, head_size=None, cache=None, epsilon=1e-6, forget=True, test_vars=None, name=None): """Memory-efficient Multihead scaled-dot-product self-attention. Based on Tensor2Tensor version but adds optional caching. Returns multihead-self-attention(layer_norm(x)) Computes one attention head at a time to avoid exhausting memory. If forget=True, then forget all forwards activations and recompute on the backwards pass. Args: x: a Tensor with shape [batch, length, input_size] bias: an attention bias tensor broadcastable to [batch, 1, length, length] num_heads: an integer head_size: an optional integer - defaults to input_size/num_heads cache: Optional dict containing tensors which are the results of previous attentions, used for fast decoding. Expects the dict to contain two keys ('k' and 'v'), for the initial call the values for these keys should be empty Tensors of the appropriate shape. 
'k' [batch_size, 0, key_channels] 'v' [batch_size, 0, value_channels] epsilon: a float, for layer norm forget: a boolean - forget forwards activations and recompute on backprop test_vars: optional tuple of variables for testing purposes name: an optional string Returns: A Tensor. """ io_size = x.get_shape().as_list()[-1] if head_size is None: assert io_size % num_heads == 0 head_size = io_size / num_heads def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias): """Forward function.""" n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias) wqkv_split = tf.unstack(wqkv, num=num_heads) wo_split = tf.unstack(wo, num=num_heads) y = 0 if cache is not None: cache_k = [] cache_v = [] for h in range(num_heads): with tf.control_dependencies([y] if h > 0 else []): combined = tf.nn.conv1d(n, wqkv_split[h], 1, 'SAME') q, k, v = tf.split(combined, 3, axis=2) if cache is not None: k = tf.concat([cache['k'][:, h], k], axis=1) v = tf.concat([cache['v'][:, h], v], axis=1) cache_k.append(k) cache_v.append(v) o = common_attention.scaled_dot_product_attention_simple( q, k, v, attention_bias) y += tf.nn.conv1d(o, wo_split[h], 1, 'SAME') if cache is not None: cache['k'] = tf.stack(cache_k, axis=1) cache['v'] = tf.stack(cache_v, axis=1) return y key = ( 'multihead_self_attention_memory_efficient %s %s' % (num_heads, epsilon)) if not forget: forward_fn = forward_internal elif key in _function_cache: forward_fn = _function_cache[key] else: @function.Defun(compiled=True) def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy): """Custom gradient function.""" with tf.control_dependencies([dy]): n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias) wqkv_split = tf.unstack(wqkv, num=num_heads) wo_split = tf.unstack(wo, num=num_heads) deps = [] dwqkvs = [] dwos = [] dn = 0 for h in range(num_heads): with tf.control_dependencies(deps): combined = tf.nn.conv1d(n, wqkv_split[h], 1, 'SAME') q, k, v = tf.split(combined, 3, axis=2) o = 
common_attention.scaled_dot_product_attention_simple( q, k, v, attention_bias) partial_y = tf.nn.conv1d(o, wo_split[h], 1, 'SAME') pdn, dwqkvh, dwoh = tf.gradients( ys=[partial_y], xs=[n, wqkv_split[h], wo_split[h]], grad_ys=[dy]) dn += pdn dwqkvs.append(dwqkvh) dwos.append(dwoh) deps = [dn, dwqkvh, dwoh] dwqkv = tf.stack(dwqkvs) dwo = tf.stack(dwos) with tf.control_dependencies(deps): dx, dnorm_scale, dnorm_bias = tf.gradients( ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn]) return (dx, dwqkv, dwo, tf.zeros_like(attention_bias), dnorm_scale, dnorm_bias) @function.Defun( grad_func=grad_fn, compiled=True, separate_compiled_gradients=True) def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias): return forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias) _function_cache[key] = forward_fn if bias is not None: bias = tf.squeeze(bias, 1) with tf.variable_scope(name, default_name='multihead_attention', values=[x]): if test_vars is not None: wqkv, wo, norm_scale, norm_bias = list(test_vars) else: wqkv = tf.get_variable( 'wqkv', [num_heads, 1, io_size, 3 * head_size], initializer=tf.random_normal_initializer(stddev=io_size**-0.5)) wo = tf.get_variable( 'wo', [num_heads, 1, head_size, io_size], initializer=tf.random_normal_initializer( stddev=(head_size * num_heads)**-0.5)) norm_scale, norm_bias = common_layers.layer_norm_vars(io_size) y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias) y.set_shape(x.get_shape()) # pytype: disable=attribute-error return y class TransformerEncoder(snt.AbstractModule): """Transformer encoder. Sonnet Transformer encoder module as described in Vaswani et al. 2017. Uses the Tensor2Tensor multihead_attention function for full self attention (no masking). Layer norm is applied inside the residual path as in sparse transformers (Child 2019). This module expects inputs to be already embedded, and does not add position embeddings. 
""" def __init__(self, hidden_size=256, fc_size=1024, num_heads=4, layer_norm=True, num_layers=8, dropout_rate=0.2, re_zero=True, memory_efficient=False, name='transformer_encoder'): """Initializes TransformerEncoder. Args: hidden_size: Size of embedding vectors. fc_size: Size of fully connected layer. num_heads: Number of attention heads. layer_norm: If True, apply layer normalization num_layers: Number of Transformer blocks, where each block contains a multi-head attention layer and a MLP. dropout_rate: Dropout rate applied immediately after the ReLU in each fully-connected layer. re_zero: If True, alpha scale residuals with zero init. memory_efficient: If True, recompute gradients for memory savings. name: Name of variable scope """ super(TransformerEncoder, self).__init__(name=name) self.hidden_size = hidden_size self.num_heads = num_heads self.layer_norm = layer_norm self.fc_size = fc_size self.num_layers = num_layers self.dropout_rate = dropout_rate self.re_zero = re_zero self.memory_efficient = memory_efficient def _build(self, inputs, is_training=False): """Passes inputs through Transformer encoder network. Args: inputs: Tensor of shape [batch_size, sequence_length, embed_size]. Zero embeddings are masked in self-attention. is_training: If True, dropout is applied. Returns: output: Tensor of shape [batch_size, sequence_length, embed_size]. """ if is_training: dropout_rate = self.dropout_rate else: dropout_rate = 0. # Identify elements with all zeros as padding, and create bias to mask # out padding elements in self attention. encoder_padding = common_attention.embedding_to_padding(inputs) encoder_self_attention_bias = ( common_attention.attention_bias_ignore_padding(encoder_padding)) x = inputs for layer_num in range(self.num_layers): with tf.variable_scope('layer_{}'.format(layer_num)): # Multihead self-attention from Tensor2Tensor. 
res = x if self.memory_efficient: res = multihead_self_attention_memory_efficient( res, bias=encoder_self_attention_bias, num_heads=self.num_heads, head_size=self.hidden_size // self.num_heads, forget=True if is_training else False, name='self_attention' ) else: if self.layer_norm: res = common_layers.layer_norm(res, name='self_attention') res = common_attention.multihead_attention( res, memory_antecedent=None, bias=encoder_self_attention_bias, total_key_depth=self.hidden_size, total_value_depth=self.hidden_size, output_depth=self.hidden_size, num_heads=self.num_heads, dropout_rate=0., make_image_summary=False, name='self_attention') if self.re_zero: res *= tf.get_variable('self_attention/alpha', initializer=0.) if dropout_rate: res = tf.nn.dropout(res, rate=dropout_rate) x += res # MLP res = x if self.layer_norm: res = common_layers.layer_norm(res, name='fc') res = tf.layers.dense( res, self.fc_size, activation=tf.nn.relu, name='fc_1') res = tf.layers.dense(res, self.hidden_size, name='fc_2') if self.re_zero: res *= tf.get_variable('fc/alpha', initializer=0.) if dropout_rate: res = tf.nn.dropout(res, rate=dropout_rate) x += res if self.layer_norm: output = common_layers.layer_norm(x, name='output') else: output = x return output class TransformerDecoder(snt.AbstractModule): """Transformer decoder. Sonnet Transformer decoder module as described in Vaswani et al. 2017. Uses the Tensor2Tensor multihead_attention function for masked self attention, and non-masked cross attention attention. Layer norm is applied inside the residual path as in sparse transformers (Child 2019). This module expects inputs to be already embedded, and does not add position embeddings. """ def __init__(self, hidden_size=256, fc_size=1024, num_heads=4, layer_norm=True, num_layers=8, dropout_rate=0.2, re_zero=True, memory_efficient=False, name='transformer_decoder'): """Initializes TransformerDecoder. Args: hidden_size: Size of embedding vectors. fc_size: Size of fully connected layer. 
num_heads: Number of attention heads. layer_norm: If True, apply layer normalization. If mem_efficient_attention is True, then layer norm is always applied. num_layers: Number of Transformer blocks, where each block contains a multi-head attention layer and a MLP. dropout_rate: Dropout rate applied immediately after the ReLU in each fully-connected layer. re_zero: If True, alpha scale residuals with zero init. memory_efficient: If True, recompute gradients for memory savings. name: Name of variable scope """ super(TransformerDecoder, self).__init__(name=name) self.hidden_size = hidden_size self.num_heads = num_heads self.layer_norm = layer_norm self.fc_size = fc_size self.num_layers = num_layers self.dropout_rate = dropout_rate self.re_zero = re_zero self.memory_efficient = memory_efficient def _build(self, inputs, sequential_context_embeddings=None, is_training=False, cache=None): """Passes inputs through Transformer decoder network. Args: inputs: Tensor of shape [batch_size, sequence_length, embed_size]. Zero embeddings are masked in self-attention. sequential_context_embeddings: Optional tensor with global context (e.g image embeddings) of shape [batch_size, context_seq_length, context_embed_size]. is_training: If True, dropout is applied. cache: Optional dict containing tensors which are the results of previous attentions, used for fast decoding. Expects the dict to contain two keys ('k' and 'v'), for the initial call the values for these keys should be empty Tensors of the appropriate shape. 'k' [batch_size, 0, key_channels] 'v' [batch_size, 0, value_channels] Returns: output: Tensor of shape [batch_size, sequence_length, embed_size]. """ if is_training: dropout_rate = self.dropout_rate else: dropout_rate = 0. # create bias to mask future elements for causal self-attention. 
seq_length = tf.shape(inputs)[1] decoder_self_attention_bias = common_attention.attention_bias_lower_triangle( seq_length) # If using sequential_context, identify elements with all zeros as padding, # and create bias to mask out padding elements in self attention. if sequential_context_embeddings is not None: encoder_padding = common_attention.embedding_to_padding( sequential_context_embeddings) encoder_decoder_attention_bias = ( common_attention.attention_bias_ignore_padding(encoder_padding)) x = inputs for layer_num in range(self.num_layers): with tf.variable_scope('layer_{}'.format(layer_num)): # If using cached decoding, access cache for current layer, and create # bias that enables un-masked attention into the cache if cache is not None: layer_cache = cache[layer_num] layer_decoder_bias = tf.zeros([1, 1, 1, 1]) # Otherwise use standard masked bias else: layer_cache = None layer_decoder_bias = decoder_self_attention_bias # Multihead self-attention from Tensor2Tensor. res = x if self.memory_efficient: res = multihead_self_attention_memory_efficient( res, bias=layer_decoder_bias, cache=layer_cache, num_heads=self.num_heads, head_size=self.hidden_size // self.num_heads, forget=True if is_training else False, name='self_attention' ) else: if self.layer_norm: res = common_layers.layer_norm(res, name='self_attention') res = common_attention.multihead_attention( res, memory_antecedent=None, bias=layer_decoder_bias, total_key_depth=self.hidden_size, total_value_depth=self.hidden_size, output_depth=self.hidden_size, num_heads=self.num_heads, cache=layer_cache, dropout_rate=0., make_image_summary=False, name='self_attention') if self.re_zero: res *= tf.get_variable('self_attention/alpha', initializer=0.) 
if dropout_rate: res = tf.nn.dropout(res, rate=dropout_rate) x += res # Optional cross attention into sequential context if sequential_context_embeddings is not None: res = x if self.layer_norm: res = common_layers.layer_norm(res, name='cross_attention') res = common_attention.multihead_attention( res, memory_antecedent=sequential_context_embeddings, bias=encoder_decoder_attention_bias, total_key_depth=self.hidden_size, total_value_depth=self.hidden_size, output_depth=self.hidden_size, num_heads=self.num_heads, dropout_rate=0., make_image_summary=False, name='cross_attention') if self.re_zero: res *= tf.get_variable('cross_attention/alpha', initializer=0.) if dropout_rate: res = tf.nn.dropout(res, rate=dropout_rate) x += res # FC layers res = x if self.layer_norm: res = common_layers.layer_norm(res, name='fc') res = tf.layers.dense( res, self.fc_size, activation=tf.nn.relu, name='fc_1') res = tf.layers.dense(res, self.hidden_size, name='fc_2') if self.re_zero: res *= tf.get_variable('fc/alpha', initializer=0.) 
if dropout_rate: res = tf.nn.dropout(res, rate=dropout_rate) x += res if self.layer_norm: output = common_layers.layer_norm(x, name='output') else: output = x return output def create_init_cache(self, batch_size): """Creates empty cache dictionary for use in fast decoding.""" def compute_cache_shape_invariants(tensor): """Helper function to get dynamic shapes for cache tensors.""" shape_list = tensor.shape.as_list() if len(shape_list) == 4: return tf.TensorShape( [shape_list[0], shape_list[1], None, shape_list[3]]) elif len(shape_list) == 3: return tf.TensorShape([shape_list[0], None, shape_list[2]]) # Build cache k = common_attention.split_heads( tf.zeros([batch_size, 0, self.hidden_size]), self.num_heads) v = common_attention.split_heads( tf.zeros([batch_size, 0, self.hidden_size]), self.num_heads) cache = [{'k': k, 'v': v} for _ in range(self.num_layers)] shape_invariants = tf.nest.map_structure( compute_cache_shape_invariants, cache) return cache, shape_invariants def conv_residual_block(inputs, output_channels=None, downsample=False, kernel_size=3, re_zero=True, dropout_rate=0., name='conv_residual_block'): """Convolutional block with residual connections for 2D or 3D inputs. Args: inputs: Input tensor of shape [batch_size, height, width, channels] or [batch_size, height, width, depth, channels]. output_channels: Number of output channels. downsample: If True, downsample by 1/2 in this block. kernel_size: Spatial size of convolutional kernels. re_zero: If True, alpha scale residuals with zero init. dropout_rate: Dropout rate applied after second ReLU in residual path. name: Name for variable scope. Returns: outputs: Output tensor of shape [batch_size, height, width, output_channels] or [batch_size, height, width, depth, output_channels]. 
""" with tf.variable_scope(name): input_shape = inputs.get_shape().as_list() num_dims = len(input_shape) - 2 if num_dims == 2: conv = tf.layers.conv2d elif num_dims == 3: conv = tf.layers.conv3d input_channels = input_shape[-1] if output_channels is None: output_channels = input_channels if downsample: shortcut = conv( inputs, filters=output_channels, strides=2, kernel_size=kernel_size, padding='same', name='conv_shortcut') else: shortcut = inputs res = inputs res = tf.nn.relu(res) res = conv( res, filters=input_channels, kernel_size=kernel_size, padding='same', name='conv_1') res = tf.nn.relu(res) if dropout_rate: res = tf.nn.dropout(res, rate=dropout_rate) if downsample: out_strides = 2 else: out_strides = 1 res = conv( res, filters=output_channels, kernel_size=kernel_size, padding='same', strides=out_strides, name='conv_2') if re_zero: res *= tf.get_variable('alpha', initializer=0.) return shortcut + res class ResNet(snt.AbstractModule): """ResNet architecture for 2D image or 3D voxel inputs.""" def __init__(self, num_dims, hidden_sizes=(64, 256), num_blocks=(2, 2), dropout_rate=0.1, re_zero=True, name='res_net'): """Initializes ResNet. Args: num_dims: Number of spatial dimensions. 2 for images or 3 for voxels. hidden_sizes: Sizes of hidden layers in resnet blocks. num_blocks: Number of resnet blocks at each size. dropout_rate: Dropout rate applied immediately after the ReLU in each fully-connected layer. re_zero: If True, alpha scale residuals with zero init. name: Name of variable scope """ super(ResNet, self).__init__(name=name) self.num_dims = num_dims self.hidden_sizes = hidden_sizes self.num_blocks = num_blocks self.dropout_rate = dropout_rate self.re_zero = re_zero def _build(self, inputs, is_training=False): """Passes inputs through resnet. Args: inputs: Tensor of shape [batch_size, height, width, channels] or [batch_size, height, width, depth, channels]. is_training: If True, dropout is applied. 
Returns: output: Tensor of shape [batch_size, height, width, depth, output_size]. """ if is_training: dropout_rate = self.dropout_rate else: dropout_rate = 0. # Initial projection with large kernel as in original resnet architecture if self.num_dims == 3: conv = tf.layers.conv3d elif self.num_dims == 2: conv = tf.layers.conv2d x = conv( inputs, filters=self.hidden_sizes[0], kernel_size=7, strides=2, padding='same', name='conv_input') if self.num_dims == 2: x = tf.layers.max_pooling2d( x, strides=2, pool_size=3, padding='same', name='pool_input') for d, (hidden_size, blocks) in enumerate(zip(self.hidden_sizes, self.num_blocks)): with tf.variable_scope('resolution_{}'.format(d)): # Downsample at the start of each collection of blocks x = conv_residual_block( x, downsample=False if d == 0 else True, dropout_rate=dropout_rate, output_channels=hidden_size, re_zero=self.re_zero, name='block_1_downsample') for i in range(blocks - 1): x = conv_residual_block( x, dropout_rate=dropout_rate, output_channels=hidden_size, re_zero=self.re_zero, name='block_{}'.format(i + 2)) return x class VertexModel(snt.AbstractModule): """Autoregressive generative model of quantized mesh vertices. Operates on flattened vertex sequences with a stopping token: [z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, z_n, STOP] Input vertex coordinates are embedded and tagged with learned coordinate and position indicators. A transformer decoder outputs logits for a quantized vertex distribution. """ def __init__(self, decoder_config, quantization_bits, class_conditional=False, num_classes=55, max_num_input_verts=2500, use_discrete_embeddings=True, name='vertex_model'): """Initializes VertexModel. Args: decoder_config: Dictionary with TransformerDecoder config quantization_bits: Number of quantization used in mesh preprocessing. class_conditional: If True, then condition on learned class embeddings. num_classes: Number of classes to condition on. max_num_input_verts: Maximum number of vertices. 
Used for learned position embeddings. use_discrete_embeddings: If True, use discrete rather than continuous vertex embeddings. name: Name of variable scope """ super(VertexModel, self).__init__(name=name) self.embedding_dim = decoder_config['hidden_size'] self.class_conditional = class_conditional self.num_classes = num_classes self.max_num_input_verts = max_num_input_verts self.quantization_bits = quantization_bits self.use_discrete_embeddings = use_discrete_embeddings with self._enter_variable_scope(): self.decoder = TransformerDecoder(**decoder_config) @snt.reuse_variables def _embed_class_label(self, labels): """Embeds class label with learned embedding matrix.""" init_dict = {'embeddings': tf.glorot_uniform_initializer} return snt.Embed( vocab_size=self.num_classes, embed_dim=self.embedding_dim, initializers=init_dict, densify_gradients=True, name='class_label')(labels) @snt.reuse_variables def _prepare_context(self, context, is_training=False): """Prepare class label context.""" if self.class_conditional: global_context_embedding = self._embed_class_label(context['class_label']) else: global_context_embedding = None return global_context_embedding, None @snt.reuse_variables def _embed_inputs(self, vertices, global_context_embedding=None): """Embeds flat vertices and adds position and coordinate information.""" # Dequantize inputs and get shapes input_shape = tf.shape(vertices) batch_size, seq_length = input_shape[0], input_shape[1] # Coord indicators (x, y, z) coord_embeddings = snt.Embed( vocab_size=3, embed_dim=self.embedding_dim, initializers={'embeddings': tf.glorot_uniform_initializer}, densify_gradients=True, name='coord_embeddings')(tf.mod(tf.range(seq_length), 3)) # Position embeddings pos_embeddings = snt.Embed( vocab_size=self.max_num_input_verts, embed_dim=self.embedding_dim, initializers={'embeddings': tf.glorot_uniform_initializer}, densify_gradients=True, name='coord_embeddings')(tf.floordiv(tf.range(seq_length), 3)) # Discrete vertex value 
embeddings if self.use_discrete_embeddings: vert_embeddings = snt.Embed( vocab_size=2**self.quantization_bits + 1, embed_dim=self.embedding_dim, initializers={'embeddings': tf.glorot_uniform_initializer}, densify_gradients=True, name='value_embeddings')(vertices) # Continuous vertex value embeddings else: vert_embeddings = tf.layers.dense( dequantize_verts(vertices[..., None], self.quantization_bits), self.embedding_dim, use_bias=True, name='value_embeddings') # Step zero embeddings if global_context_embedding is None: zero_embed = tf.get_variable( 'embed_zero', shape=[1, 1, self.embedding_dim]) zero_embed_tiled = tf.tile(zero_embed, [batch_size, 1, 1]) else: zero_embed_tiled = global_context_embedding[:, None] # Aggregate embeddings embeddings = vert_embeddings + (coord_embeddings + pos_embeddings)[None] embeddings = tf.concat([zero_embed_tiled, embeddings], axis=1) return embeddings @snt.reuse_variables def _project_to_logits(self, inputs): """Projects transformer outputs to logits for predictive distribution.""" return tf.layers.dense( inputs, 2**self.quantization_bits + 1, # + 1 for stopping token use_bias=True, kernel_initializer=tf.zeros_initializer(), name='project_to_logits') @snt.reuse_variables def _create_dist(self, vertices, global_context_embedding=None, sequential_context_embeddings=None, temperature=1., top_k=0, top_p=1., is_training=False, cache=None): """Outputs categorical dist for quantized vertex coordinates.""" # Embed inputs decoder_inputs = self._embed_inputs(vertices, global_context_embedding) if cache is not None: decoder_inputs = decoder_inputs[:, -1:] # pass through decoder outputs = self.decoder( decoder_inputs, cache=cache, sequential_context_embeddings=sequential_context_embeddings, is_training=is_training) # Get logits and optionally process for sampling logits = self._project_to_logits(outputs) logits /= temperature logits = top_k_logits(logits, top_k) logits = top_p_logits(logits, top_p) cat_dist = tfd.Categorical(logits=logits) 
return cat_dist def _build(self, batch, is_training=False): """Pass batch through vertex model and get log probabilities under model. Args: batch: Dictionary containing: 'vertices_flat': int32 vertex tensors of shape [batch_size, seq_length]. is_training: If True, use dropout. Returns: pred_dist: tfd.Categorical predictive distribution with batch shape [batch_size, seq_length]. """ global_context, seq_context = self._prepare_context( batch, is_training=is_training) pred_dist = self._create_dist( batch['vertices_flat'][:, :-1], # Last element not used for preds global_context_embedding=global_context, sequential_context_embeddings=seq_context, is_training=is_training) return pred_dist def sample(self, num_samples, context=None, max_sample_length=None, temperature=1., top_k=0, top_p=1., recenter_verts=True, only_return_complete=True): """Autoregressive sampling with caching. Args: num_samples: Number of samples to produce. context: Dictionary of context, such as class labels. See _prepare_context for details. max_sample_length: Maximum length of sampled vertex sequences. Sequences that do not complete are truncated. temperature: Scalar softmax temperature > 0. top_k: Number of tokens to keep for top-k sampling. top_p: Proportion of probability mass to keep for top-p sampling. recenter_verts: If True, center vertex samples around origin. This should be used if model is trained using shift augmentations. only_return_complete: If True, only return completed samples. Otherwise return all samples along with completed indicator. Returns: outputs: Output dictionary with fields: 'completed': Boolean tensor of shape [num_samples]. If True then corresponding sample completed within max_sample_length. 'vertices': Tensor of samples with shape [num_samples, num_verts, 3]. 'num_vertices': Tensor indicating number of vertices for each example in padded vertex samples. 'vertices_mask': Tensor of shape [num_samples, num_verts] that masks corresponding invalid elements in 'vertices'. 
""" # Obtain context for decoder global_context, seq_context = self._prepare_context( context, is_training=False) # num_samples is the minimum value of num_samples and the batch size of # context inputs (if present). if global_context is not None: num_samples = tf.minimum(num_samples, tf.shape(global_context)[0]) global_context = global_context[:num_samples] if seq_context is not None: seq_context = seq_context[:num_samples] elif seq_context is not None: num_samples = tf.minimum(num_samples, tf.shape(seq_context)[0]) seq_context = seq_context[:num_samples] def _loop_body(i, samples, cache): """While-loop body for autoregression calculation.""" cat_dist = self._create_dist( samples, global_context_embedding=global_context, sequential_context_embeddings=seq_context, cache=cache, temperature=temperature, top_k=top_k, top_p=top_p) next_sample = cat_dist.sample() samples = tf.concat([samples, next_sample], axis=1) return i + 1, samples, cache def _stopping_cond(i, samples, cache): """Stopping condition for sampling while-loop.""" del i, cache # Unused return tf.reduce_any(tf.reduce_all(tf.not_equal(samples, 0), axis=-1)) # Initial values for loop variables samples = tf.zeros([num_samples, 0], dtype=tf.int32) max_sample_length = max_sample_length or self.max_num_input_verts cache, cache_shape_invariants = self.decoder.create_init_cache(num_samples) _, v, _ = tf.while_loop( cond=_stopping_cond, body=_loop_body, loop_vars=(0, samples, cache), shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]), cache_shape_invariants), maximum_iterations=max_sample_length * 3 + 1, back_prop=False, parallel_iterations=1) # Check if samples completed. Samples are complete if the stopping token # is produced. completed = tf.reduce_any(tf.equal(v, 0), axis=-1) # Get the number of vertices in the sample. This requires finding the # index of the stopping token. For complete samples use to argmax to get # first nonzero index. 
stop_index_completed = tf.argmax( tf.cast(tf.equal(v, 0), tf.int32), axis=-1, output_type=tf.int32) # For incomplete samples the stopping index is just the maximum index. stop_index_incomplete = ( max_sample_length * 3 * tf.ones_like(stop_index_completed)) stop_index = tf.where( completed, stop_index_completed, stop_index_incomplete) num_vertices = tf.floordiv(stop_index, 3) # Convert to 3D vertices by reshaping and re-ordering x -> y -> z v = v[:, :(tf.reduce_max(num_vertices) * 3)] - 1 verts_dequantized = dequantize_verts(v, self.quantization_bits) vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3]) vertices = tf.stack( [vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1) # Pad samples to max sample length. This is required in order to concatenate # Samples across different replicator instances. Pad with stopping tokens # for incomplete samples. pad_size = max_sample_length - tf.shape(vertices)[1] vertices = tf.pad(vertices, [[0, 0], [0, pad_size], [0, 0]]) # 3D Vertex mask vertices_mask = tf.cast( tf.range(max_sample_length)[None] < num_vertices[:, None], tf.float32) if recenter_verts: vert_max = tf.reduce_max( vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1, keepdims=True) vert_min = tf.reduce_min( vertices + 1e10 * (1. - vertices_mask)[..., None], axis=1, keepdims=True) vert_centers = 0.5 * (vert_max + vert_min) vertices -= vert_centers vertices *= vertices_mask[..., None] if only_return_complete: vertices = tf.boolean_mask(vertices, completed) num_vertices = tf.boolean_mask(num_vertices, completed) vertices_mask = tf.boolean_mask(vertices_mask, completed) completed = tf.boolean_mask(completed, completed) # Outputs outputs = { 'completed': completed, 'vertices': vertices, 'num_vertices': num_vertices, 'vertices_mask': vertices_mask, } return outputs


class ImageToVertexModel(VertexModel):
  """Generative model of quantized mesh vertices with image conditioning.

  Operates on flattened vertex sequences with a stopping token:

  [z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, x_n, STOP]

  Input vertex coordinates are embedded and tagged with learned coordinate and
  position indicators. A transformer decoder outputs logits for a quantized
  vertex distribution. Image inputs are encoded and used to condition the
  vertex decoder.
  """

  def __init__(self,
               res_net_config,
               decoder_config,
               quantization_bits,
               use_discrete_embeddings=True,
               max_num_input_verts=2500,
               name='image_to_vertex_model'):
    """Initializes ImageToVertexModel.

    Args:
      res_net_config: Dictionary with ResNet config. Expanded as keyword
        arguments into a 2D `ResNet`.
      decoder_config: Dictionary with TransformerDecoder config.
      quantization_bits: Number of quantization bits used in mesh
        preprocessing.
      use_discrete_embeddings: If True, use discrete rather than continuous
        vertex embeddings.
      max_num_input_verts: Maximum number of vertices. Used for learned position
        embeddings.
      name: Name of variable scope
    """
    super(ImageToVertexModel, self).__init__(
        decoder_config=decoder_config,
        quantization_bits=quantization_bits,
        max_num_input_verts=max_num_input_verts,
        use_discrete_embeddings=use_discrete_embeddings,
        name=name)

    with self._enter_variable_scope():
      self.res_net = ResNet(num_dims=2, **res_net_config)

  @snt.reuse_variables
  def _prepare_context(self, context, is_training=False):
    """Encodes the conditioning image into a sequence of embeddings.

    Args:
      context: Dictionary with an 'image' entry. Presumably a float tensor of
        shape [batch_size, height, width, channels] -- TODO confirm against
        ResNet.
      is_training: Forwarded to the ResNet encoder.

    Returns:
      Tuple `(None, sequential_context_embedding)`. The second element has shape
      [batch_size, num_positions, self.embedding_dim]. The first element is
      always None for this model (presumably the global-context slot expected
      by VertexModel -- not visible here).
    """
    # Pass images through encoder. The 0.5 shift presumably centers [0, 1]
    # pixel values around zero -- TODO confirm input range.
    image_embeddings = self.res_net(
        context['image'] - 0.5, is_training=is_training)

    # Add 2D coordinate grid embedding. The size of axis 1 is used for both
    # grid axes, so the encoder output is assumed to be spatially square.
    processed_image_resolution = tf.shape(image_embeddings)[1]
    x = tf.linspace(-1., 1., processed_image_resolution)
    image_coords = tf.stack(tf.meshgrid(x, x), axis=-1)
    image_coord_embeddings = tf.layers.dense(
        image_coords,
        self.embedding_dim,
        use_bias=True,
        name='image_coord_embeddings')
    image_embeddings += image_coord_embeddings[None]

    # Reshape spatial grid to sequence
    batch_size = tf.shape(image_embeddings)[0]
    sequential_context_embedding = tf.reshape(
        image_embeddings, [batch_size, -1, self.embedding_dim])

    return None, sequential_context_embedding


class VoxelToVertexModel(VertexModel):
  """Generative model of quantized mesh vertices with voxel conditioning.

  Operates on flattened vertex sequences with a stopping token:

  [z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, x_n, STOP]

  Input vertex coordinates are embedded and tagged with learned coordinate and
  position indicators. A transformer decoder outputs logits for a quantized
  vertex distribution. Voxel inputs are encoded and used to condition the
  vertex decoder.
  """

  def __init__(self,
               res_net_config,
               decoder_config,
               quantization_bits,
               use_discrete_embeddings=True,
               max_num_input_verts=2500,
               name='voxel_to_vertex_model'):
    """Initializes VoxelToVertexModel.

    Args:
      res_net_config: Dictionary with ResNet config. Expanded as keyword
        arguments into a 3D `ResNet`.
      decoder_config: Dictionary with TransformerDecoder config.
      quantization_bits: Integer number of bits used for vertex quantization.
      use_discrete_embeddings: If True, use discrete rather than continuous
        vertex embeddings.
      max_num_input_verts: Maximum number of vertices. Used for learned position
        embeddings.
      name: Name of variable scope
    """
    super(VoxelToVertexModel, self).__init__(
        decoder_config=decoder_config,
        quantization_bits=quantization_bits,
        max_num_input_verts=max_num_input_verts,
        use_discrete_embeddings=use_discrete_embeddings,
        name=name)

    with self._enter_variable_scope():
      self.res_net = ResNet(num_dims=3, **res_net_config)

  @snt.reuse_variables
  def _prepare_context(self, context, is_training=False):
    """Encodes the conditioning voxel grid into a sequence of embeddings.

    Args:
      context: Dictionary with a 'voxels' entry holding binary (0/1) integer
        voxel occupancies; they are used as indices into a 2-entry embedding
        table.
      is_training: Forwarded to the ResNet encoder.

    Returns:
      Tuple `(None, sequential_context_embedding)`. The second element has shape
      [batch_size, num_positions, self.embedding_dim]. The first element is
      always None for this model (presumably the global-context slot expected
      by VertexModel -- not visible here).
    """
    # Embed binary input voxels.
    # NOTE(review): `self.pre_embed_dim` is not assigned anywhere in this
    # class; it must be provided by VertexModel, otherwise this raises
    # AttributeError -- confirm.
    voxel_embeddings = snt.Embed(
        vocab_size=2,
        embed_dim=self.pre_embed_dim,
        initializers={'embeddings': tf.glorot_uniform_initializer},
        densify_gradients=True,
        name='voxel_embeddings')(context['voxels'])

    # Pass embedded voxels through voxel encoder
    voxel_embeddings = self.res_net(
        voxel_embeddings, is_training=is_training)

    # Add 3D coordinate grid embedding. The size of axis 1 is used for all
    # three grid axes, so the encoder output is assumed to be a cube.
    processed_voxel_resolution = tf.shape(voxel_embeddings)[1]
    x = tf.linspace(-1., 1., processed_voxel_resolution)
    voxel_coords = tf.stack(tf.meshgrid(x, x, x), axis=-1)
    voxel_coord_embeddings = tf.layers.dense(
        voxel_coords,
        self.embedding_dim,
        use_bias=True,
        name='voxel_coord_embeddings')
    voxel_embeddings += voxel_coord_embeddings[None]

    # Reshape spatial grid to sequence
    batch_size = tf.shape(voxel_embeddings)[0]
    sequential_context_embedding = tf.reshape(
        voxel_embeddings, [batch_size, -1, self.embedding_dim])

    return None, sequential_context_embedding


class FaceModel(snt.AbstractModule):
  """Autoregressive generative model of n-gon meshes.

  Operates on sets of input vertices as well as flattened face sequences with
  new face and stopping tokens:

  [f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]

  Input vertices are encoded using a Transformer encoder.

  Input face sequences are embedded and tagged with learned position indicators,
  as well as their corresponding vertex embeddings. A transformer decoder
  outputs a pointer which is compared to each vertex embedding to obtain a
  distribution over vertex indices.

  Token convention used throughout this class: index 0 is the stopping token,
  index 1 is the new-face token, and vertex `i` is referred to by index `i + 2`
  (two learned embeddings are prepended to the vertex embeddings).
  """

  def __init__(self,
               encoder_config,
               decoder_config,
               class_conditional=True,
               num_classes=55,
               decoder_cross_attention=True,
               use_discrete_vertex_embeddings=True,
               quantization_bits=8,
               max_seq_length=5000,
               name='face_model'):
    """Initializes FaceModel.

    Args:
      encoder_config: Dictionary with TransformerEncoder config.
      decoder_config: Dictionary with TransformerDecoder config. Its
        'hidden_size' entry is used as the embedding dimension.
      class_conditional: If True, then condition on learned class embeddings.
      num_classes: Number of classes to condition on.
      decoder_cross_attention: If True, then use cross attention from decoder
        queries into encoder outputs.
      use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
      quantization_bits: Number of quantization bits for discrete vertex
        embeddings.
      max_seq_length: Maximum face sequence length. Used for learned position
        embeddings.
      name: Name of variable scope
    """
    super(FaceModel, self).__init__(name=name)
    self.embedding_dim = decoder_config['hidden_size']
    self.class_conditional = class_conditional
    self.num_classes = num_classes
    self.max_seq_length = max_seq_length
    self.decoder_cross_attention = decoder_cross_attention
    self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
    self.quantization_bits = quantization_bits

    with self._enter_variable_scope():
      self.decoder = TransformerDecoder(**decoder_config)
      self.encoder = TransformerEncoder(**encoder_config)

  @snt.reuse_variables
  def _embed_class_label(self, labels):
    """Embeds class label with learned embedding matrix.

    Args:
      labels: Integer class indices, looked up in a table with
        `self.num_classes` rows.

    Returns:
      Embeddings with a trailing dimension of size `self.embedding_dim`.
    """
    init_dict = {'embeddings': tf.glorot_uniform_initializer}
    return snt.Embed(
        vocab_size=self.num_classes,
        embed_dim=self.embedding_dim,
        initializers=init_dict,
        densify_gradients=True,
        name='class_label')(labels)

  @snt.reuse_variables
  def _prepare_context(self, context, is_training=False):
    """Prepare class label and vertex context.

    Args:
      context: Dictionary with 'vertices' and 'vertices_mask' entries, plus
        'class_label' when the model is class conditional.
      is_training: Forwarded to the vertex encoder.

    Returns:
      Tuple `(vertex_embeddings, global_context_embedding,
      sequential_context_embeddings)`. The global embedding is None when the
      model is not class conditional; the sequential embeddings are None when
      decoder cross attention is disabled.
    """
    if self.class_conditional:
      global_context_embedding = self._embed_class_label(context['class_label'])
    else:
      global_context_embedding = None

    vertex_embeddings = self._embed_vertices(
        context['vertices'],
        context['vertices_mask'],
        is_training=is_training)

    if self.decoder_cross_attention:
      # The mask is left-padded with ones so the two prepended stop / new-face
      # embeddings are never masked out.
      sequential_context_embeddings = (
          vertex_embeddings *
          tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
                 constant_values=1)[..., None])
    else:
      sequential_context_embeddings = None

    return (vertex_embeddings, global_context_embedding,
            sequential_context_embeddings)

  @snt.reuse_variables
  def _embed_vertices(self, vertices, vertices_mask, is_training=False):
    """Embeds vertices with transformer encoder.

    Args:
      vertices: Vertex coordinates with 3 coordinates on the last axis;
        presumably float values of shape [batch_size, num_vertices, 3] in the
        range expected by `quantize_verts` -- TODO confirm.
      vertices_mask: Mask over the vertex axis; multiplied into the embeddings
        so padded vertices contribute zeros.
      is_training: Forwarded to the Transformer encoder.

    Returns:
      Encoder outputs for the sequence made of the two learned stop / new-face
      embeddings followed by the vertex embeddings.
    """
    if self.use_discrete_vertex_embeddings:
      # Quantized coordinates lie in [0, 2**quantization_bits - 1]. Keep the
      # historical table size of 256 as a floor so variables created with
      # quantization_bits <= 8 keep their shape, and grow the table when more
      # bits are requested so lookups never go out of range.
      vocab_size = max(256, 2 ** self.quantization_bits)
      vertex_embeddings = 0.
      verts_quantized = quantize_verts(vertices, self.quantization_bits)
      for c in range(3):
        vertex_embeddings += snt.Embed(
            vocab_size=vocab_size,
            embed_dim=self.embedding_dim,
            initializers={'embeddings': tf.glorot_uniform_initializer},
            densify_gradients=True,
            name='coord_{}'.format(c))(verts_quantized[..., c])
    else:
      vertex_embeddings = tf.layers.dense(
          vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
    vertex_embeddings *= vertices_mask[..., None]

    # Pad vertex embeddings with learned embeddings for stopping and new face
    # tokens
    stopping_embeddings = tf.get_variable(
        'stopping_embeddings', shape=[1, 2, self.embedding_dim])
    stopping_embeddings = tf.tile(stopping_embeddings,
                                  [tf.shape(vertices)[0], 1, 1])
    vertex_embeddings = tf.concat(
        [stopping_embeddings, vertex_embeddings], axis=1)

    # Pass through Transformer encoder
    vertex_embeddings = self.encoder(vertex_embeddings, is_training=is_training)
    return vertex_embeddings

  @snt.reuse_variables
  def _embed_inputs(self, faces_long, vertex_embeddings,
                    global_context_embedding=None):
    """Embeds face sequences and adds learned sequence-position embeddings.

    Args:
      faces_long: Integer face-token sequences, [batch_size, seq_length], used
        as per-example indices into `vertex_embeddings`.
      vertex_embeddings: Output of `_embed_vertices`.
      global_context_embedding: Optional per-example embedding used as the
        step-zero input. When None, a learned 'embed_zero' variable is used.

    Returns:
      Decoder inputs of shape [batch_size, seq_length + 1, embedding_dim]; the
      first position is the step-zero embedding.
    """
    # Face value embeddings are gathered vertex embeddings
    face_embeddings = tf.gather(vertex_embeddings, faces_long, batch_dims=1)

    # Position embeddings. The table has `self.max_seq_length` rows, so the
    # face sequence must not be longer than that.
    pos_embeddings = snt.Embed(
        vocab_size=self.max_seq_length,
        embed_dim=self.embedding_dim,
        initializers={'embeddings': tf.glorot_uniform_initializer},
        densify_gradients=True,
        name='coord_embeddings')(tf.range(tf.shape(faces_long)[1]))

    # Step zero embeddings
    batch_size = tf.shape(face_embeddings)[0]
    if global_context_embedding is None:
      zero_embed = tf.get_variable(
          'embed_zero', shape=[1, 1, self.embedding_dim])
      zero_embed_tiled = tf.tile(zero_embed, [batch_size, 1, 1])
    else:
      zero_embed_tiled = global_context_embedding[:, None]

    # Aggregate embeddings
    embeddings = face_embeddings + pos_embeddings[None]
    embeddings = tf.concat([zero_embed_tiled, embeddings], axis=1)

    return embeddings

  @snt.reuse_variables
  def _project_to_pointers(self, inputs):
    """Projects transformer outputs to pointer vectors.

    The dense kernel is initialised to zeros.
    """
    return tf.layers.dense(
        inputs,
        self.embedding_dim,
        use_bias=True,
        kernel_initializer=tf.zeros_initializer(),
        name='project_to_pointers'
    )

  @snt.reuse_variables
  def _create_dist(self,
                   vertex_embeddings,
                   vertices_mask,
                   faces_long,
                   global_context_embedding=None,
                   sequential_context_embeddings=None,
                   temperature=1.,
                   top_k=0,
                   top_p=1.,
                   is_training=False,
                   cache=None):
    """Outputs categorical dist for vertex indices.

    Args:
      vertex_embeddings: Output of `_embed_vertices` (includes the two
        prepended stop / new-face embeddings).
      vertices_mask: Mask over the (unpadded) vertex axis.
      faces_long: Integer face-token sequences fed to the decoder.
      global_context_embedding: Optional step-zero embedding.
      sequential_context_embeddings: Optional embeddings for decoder cross
        attention.
      temperature: Scalar softmax temperature; logits are divided by it.
      top_k: Passed to `top_k_logits`.
      top_p: Passed to `top_p_logits`.
      is_training: Forwarded to the decoder.
      cache: Optional decoder cache. When given, only the last input position
        is passed through the decoder.

    Returns:
      tfd.Categorical over the stop token, the new-face token and the vertices.
    """
    # Embed inputs
    decoder_inputs = self._embed_inputs(
        faces_long, vertex_embeddings, global_context_embedding)

    # Pass through Transformer decoder
    if cache is not None:
      decoder_inputs = decoder_inputs[:, -1:]
    decoder_outputs = self.decoder(
        decoder_inputs,
        cache=cache,
        sequential_context_embeddings=sequential_context_embeddings,
        is_training=is_training)

    # Get pointers
    pred_pointers = self._project_to_pointers(decoder_outputs)

    # Get logits and mask
    logits = tf.matmul(pred_pointers, vertex_embeddings, transpose_b=True)
    logits /= tf.sqrt(float(self.embedding_dim))
    # Left-pad the mask with ones so the stop / new-face slots stay selectable.
    f_verts_mask = tf.pad(
        vertices_mask, [[0, 0], [2, 0]], constant_values=1.)[:, None]
    logits *= f_verts_mask
    logits -= (1. - f_verts_mask) * 1e9
    logits /= temperature
    logits = top_k_logits(logits, top_k)
    logits = top_p_logits(logits, top_p)
    return tfd.Categorical(logits=logits)

  def _build(self, batch, is_training=False):
    """Pass batch through face model and get log probabilities.

    Args:
      batch: Dictionary containing:
        'vertices': Tensor of shape [batch_size, num_vertices, 3].
        'faces': int32 tensor of shape [batch_size, seq_length] with flattened
          faces.
        'vertices_mask': float32 tensor with shape
          [batch_size, num_vertices] that masks padded elements in 'vertices'.
        'class_label': Class indices; only read when the model is class
          conditional.
      is_training: If True, use dropout.

    Returns:
      pred_dist: tfd.Categorical predictive distribution with batch shape
          [batch_size, seq_length].
    """
    vertex_embeddings, global_context, seq_context = self._prepare_context(
        batch, is_training=is_training)
    # The last face token is dropped from the inputs; together with the
    # prepended step-zero embedding this yields one prediction per token.
    pred_dist = self._create_dist(
        vertex_embeddings,
        batch['vertices_mask'],
        batch['faces'][:, :-1],
        global_context_embedding=global_context,
        sequential_context_embeddings=seq_context,
        is_training=is_training)
    return pred_dist

  def sample(self,
             context,
             max_sample_length=None,
             temperature=1.,
             top_k=0,
             top_p=1.,
             only_return_complete=True):
    """Sample from face model using caching.

    Args:
      context: Dictionary of context, including 'vertices' and 'vertices_mask'.
        See _prepare_context for details.
      max_sample_length: Maximum length of sampled face sequences. Sequences
        that do not complete are truncated. Defaults to `self.max_seq_length`
        and should not exceed it, since the position embedding table has
        `self.max_seq_length` entries.
      temperature: Scalar softmax temperature > 0.
      top_k: Number of tokens to keep for top-k sampling.
      top_p: Proportion of probability mass to keep for top-p sampling.
      only_return_complete: If True, only return completed samples. Otherwise
        return all samples along with completed indicator.

    Returns:
      outputs: Output dictionary with fields:
        'context': The input context; filtered to the completed samples when
          `only_return_complete` is True.
        'completed': Boolean tensor of shape [num_samples]. If True then
          corresponding sample completed within max_sample_length.
        'faces': int32 tensor of sampled face tokens with shape
          [num_samples, max_sample_length], zero-padded after the end of each
          sample.
        'num_face_indices': Tensor with the number of face tokens in each
          sample: up to and including the stopping token for completed
          samples, and up to and including the last new-face token (which is
          overwritten with zero) for incomplete ones.
    """
    vertex_embeddings, global_context, seq_context = self._prepare_context(
        context, is_training=False)
    num_samples = tf.shape(vertex_embeddings)[0]

    def _loop_body(i, samples, cache):
      """While-loop body for autoregression calculation."""
      pred_dist = self._create_dist(
          vertex_embeddings,
          context['vertices_mask'],
          samples,
          global_context_embedding=global_context,
          sequential_context_embeddings=seq_context,
          cache=cache,
          temperature=temperature,
          top_k=top_k,
          top_p=top_p)
      next_sample = pred_dist.sample()[:, -1:]
      samples = tf.concat([samples, next_sample], axis=1)
      return i + 1, samples, cache

    def _stopping_cond(i, samples, cache):
      """Stopping conditions for autoregressive calculation."""
      del i, cache  # Unused
      # Continue while at least one sample has not produced a stop token (0).
      return tf.reduce_any(tf.reduce_all(tf.not_equal(samples, 0), axis=-1))

    # While loop sampling with caching
    samples = tf.zeros([num_samples, 0], dtype=tf.int32)
    max_sample_length = max_sample_length or self.max_seq_length
    cache, cache_shape_invariants = self.decoder.create_init_cache(num_samples)
    _, f, _ = tf.while_loop(
        cond=_stopping_cond,
        body=_loop_body,
        loop_vars=(0, samples, cache),
        shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]),
                          cache_shape_invariants),
        back_prop=False,
        parallel_iterations=1,
        maximum_iterations=max_sample_length)

    # Record completed samples
    complete_samples = tf.reduce_any(tf.equal(f, 0), axis=-1)

    # Find number of faces
    sample_length = tf.shape(f)[-1]
    # Get largest new face (1) index as stopping point for incomplete samples.
    max_one_ind = tf.reduce_max(
        tf.range(sample_length)[None] * tf.cast(tf.equal(f, 1), tf.int32),
        axis=-1)
    zero_inds = tf.cast(
        tf.argmax(tf.cast(tf.equal(f, 0), tf.int32), axis=-1), tf.int32)
    num_face_indices = tf.where(complete_samples, zero_inds, max_one_ind) + 1

    # Mask faces beyond stopping token with zeros
    # This mask has a -1 in order to replace the last new face token with zero
    faces_mask = tf.cast(
        tf.range(sample_length)[None] < num_face_indices[:, None] - 1, tf.int32)
    f *= faces_mask

    # Pad to maximum size with zeros
    pad_size = max_sample_length - sample_length
    f = tf.pad(f, [[0, 0], [0, pad_size]])

    if only_return_complete:
      f = tf.boolean_mask(f, complete_samples)
      num_face_indices = tf.boolean_mask(num_face_indices, complete_samples)
      context = tf.nest.map_structure(
          lambda x: tf.boolean_mask(x, complete_samples), context)
      complete_samples = tf.boolean_mask(complete_samples, complete_samples)

    # outputs
    outputs = {
        'context': context,
        'completed': complete_samples,
        'faces': f,
        'num_face_indices': num_face_indices,
    }
    return outputs