diff --git a/README.md b/README.md
index ada23e1..3268d62 100644
--- a/README.md
+++ b/README.md
@@ -50,6 +50,8 @@ https://deepmind.com/research/publications/
 * [Graph Matching Networks for Learning the Similarity of Graph Structured Objects](graph_matching_networks), ICML 2019
 * [REGAL: Transfer Learning for Fast Optimization of Computation Graphs](regal)
+* [PolyGen: An Autoregressive Generative Model of 3D Meshes](polygen), ICML 2020
+
 ## Disclaimer
diff --git a/polygen/README.md b/polygen/README.md
new file mode 100644
index 0000000..2186591
--- /dev/null
+++ b/polygen/README.md
@@ -0,0 +1,65 @@
+# PolyGen: An Autoregressive Generative Model of 3D Meshes
+
+![](media/example_samples.png)
+
+This package provides an implementation of PolyGen as described in:
+
+> **PolyGen: An Autoregressive Generative Model of 3D Meshes**, *Charlie Nash, Yaroslav Ganin, S. M. Ali Eslami, Peter W. Battaglia*, ICML, 2020. ([abs](https://arxiv.org/abs/2002.10880))
+
+PolyGen is a generative model of 3D meshes that sequentially outputs mesh
+vertices and faces. PolyGen consists of two parts: a vertex model, which
+unconditionally models mesh vertices, and a face model, which models mesh
+faces conditioned on input vertices. The vertex model uses a masked Transformer
+decoder to express a distribution over vertex sequences. For the face model
+we combine Transformers with [pointer networks](https://arxiv.org/abs/1506.03134)
+to express a distribution over variable-length vertex sequences.
+
+In this repository we provide model code in `modules.py`, as well as data
+processing utilities in `data_utils.py`. We also provide Colabs that demo
+training PolyGen from scratch on a toy dataset, as well as sampling from a
+pre-trained model.
+
+There are some minor differences between this implementation and the paper:
+* We add global information (e.g. class label embeddings) as an additional input in the first sequence position rather than projecting it at each layer. This reduces parameters but does not significantly impact performance.
+* We use [ReZero](https://arxiv.org/abs/2003.04887), which improves training speed.
+* We train with only shifting augmentations, which we find to be as effective as the combination of augmentations described in the paper. This helps to simplify the data pre-processing pipeline.
+
+## Training Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/training.ipynb)
+
+To train a PolyGen model from scratch on a collection of simple meshes, use
+this colab. It demonstrates the data pre-processing required to create inputs
+for the vertex and face models.
+
+## Sampling pre-trained model Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/sample-pretrained.ipynb)
+
+To sample from a model pre-trained on [ShapeNet](https://www.shapenet.org/),
+use this colab. The model is class-conditional, and is trained on longer
+sequence lengths than those described in the paper. The colab uses checkpoints
+stored in this [Google Cloud Storage
+bucket](https://console.cloud.google.com/storage/browser/deepmind-research-polygen).
+
+## Installation
+
+To install the package locally, run:
+```bash
+git clone https://github.com/deepmind/deepmind-research.git
+cd deepmind-research/polygen
+pip install -e .
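+# Optional sanity check (a hedged sketch: assumes the editable install above
+# succeeded; modules.py and data_utils.py ship with this package):
+python -c "import modules, data_utils"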
+```
+
+## Giving Credit
+
+If you use this code in your work, we ask you to cite this paper:
+
+```
+@inproceedings{nash2020polygen,
+  author={Charlie Nash and Yaroslav Ganin and S. M. Ali Eslami and Peter W. Battaglia},
+  title={PolyGen: An Autoregressive Generative Model of 3D Meshes},
+  booktitle={International Conference on Machine Learning (ICML)},
+  year={2020}
+}
+```
+
+## Disclaimer
+
+This is not an official Google product.
diff --git a/polygen/data_utils.py b/polygen/data_utils.py
new file mode 100644
index 0000000..1450d40
--- /dev/null
+++ b/polygen/data_utils.py
@@ -0,0 +1,409 @@
+# Copyright 2020 DeepMind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mesh data utilities."""
+import matplotlib.pyplot as plt
+import modules
+from mpl_toolkits import mplot3d  # pylint: disable=unused-import
+from mpl_toolkits.mplot3d.art3d import Poly3DCollection
+import networkx as nx
+import numpy as np
+import six
+from six.moves import range
+import tensorflow.compat.v1 as tf
+import tensorflow_probability as tfp
+tfd = tfp.distributions
+
+
+def random_shift(vertices, shift_factor=0.25):
+  """Apply random shift to vertices."""
+  max_shift_pos = tf.cast(255 - tf.reduce_max(vertices, axis=0), tf.float32)
+  max_shift_pos = tf.maximum(max_shift_pos, 1e-9)
+
+  max_shift_neg = tf.cast(tf.reduce_min(vertices, axis=0), tf.float32)
+  max_shift_neg = tf.maximum(max_shift_neg, 1e-9)
+
+  shift = tfd.TruncatedNormal(
+      tf.zeros([1, 3]), shift_factor*255, -max_shift_neg,
+      max_shift_pos).sample()
+  shift = tf.cast(shift, tf.int32)
+  vertices += shift
+  return vertices
+
+
+def make_vertex_model_dataset(ds, apply_random_shift=False):
+  """Prepare dataset for vertex model training."""
+  def _vertex_model_map_fn(example):
+    vertices = example['vertices']
+
+    # Randomly shift vertices.
+    if apply_random_shift:
+      vertices = random_shift(vertices)
+
+    # Re-order vertex coordinates as (z, y, x).
+    vertices_permuted = tf.stack(
+        [vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
+
+    # Flatten quantized vertices, reindex starting from 1, and pad with a
+    # zero stopping token.
+    vertices_flat = tf.reshape(vertices_permuted, [-1])
+    example['vertices_flat'] = tf.pad(vertices_flat + 1, [[0, 1]])
+
+    # Create mask to indicate valid tokens after padding and batching.
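+    # (The mask is all ones at this stage; zero entries only appear after
+    # downstream padded batching, e.g. via tf.data's padded_batch, which pads
+    # both 'vertices_flat' and this mask with zeros.)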
+    example['vertices_flat_mask'] = tf.ones_like(
+        example['vertices_flat'], dtype=tf.float32)
+    return example
+  return ds.map(_vertex_model_map_fn)
+
+
+def make_face_model_dataset(
+    ds, apply_random_shift=False, shuffle_vertices=True, quantization_bits=8):
+  """Prepare dataset for face model training."""
+  def _face_model_map_fn(example):
+    vertices = example['vertices']
+
+    # Randomly shift vertices.
+    if apply_random_shift:
+      vertices = random_shift(vertices)
+    example['num_vertices'] = tf.shape(vertices)[0]
+
+    # Optionally shuffle vertices and re-order faces to match.
+    if shuffle_vertices:
+      permutation = tf.random_shuffle(tf.range(example['num_vertices']))
+      vertices = tf.gather(vertices, permutation)
+      face_permutation = tf.concat(
+          [tf.constant([0, 1], dtype=tf.int32), tf.argsort(permutation) + 2],
+          axis=0)
+      example['faces'] = tf.cast(
+          tf.gather(face_permutation, example['faces']), tf.int64)
+
+    # Vertices are quantized, so convert to floats for input to the face model.
+    example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
+    example['vertices_mask'] = tf.ones_like(
+        example['vertices'][Ellipsis, 0], dtype=tf.float32)
+    example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
+    return example
+  return ds.map(_face_model_map_fn)
+
+
+def read_obj(obj_path):
+  """Read vertices and faces from .obj file."""
+  vertex_list = []
+  flat_vertices_list = []
+  flat_vertices_indices = {}
+  flat_triangles = []
+
+  with open(obj_path) as obj_file:
+    for line in obj_file:
+      tokens = line.split()
+      if not tokens:
+        continue
+      line_type = tokens[0]
+      # We skip lines not starting with v or f.
+      if line_type == 'v':
+        vertex_list.append([float(x) for x in tokens[1:]])
+      elif line_type == 'f':
+        triangle = []
+        for i in range(len(tokens) - 1):
+          vertex_name = tokens[i + 1]
+          if vertex_name in flat_vertices_indices:
+            triangle.append(flat_vertices_indices[vertex_name])
+            continue
+          flat_vertex = []
+          for index in six.ensure_str(vertex_name).split('/'):
+            if not index:
+              continue
+            # obj triangle indices are 1 indexed, so subtract 1 here.
+            flat_vertex += vertex_list[int(index) - 1]
+          flat_vertex_index = len(flat_vertices_list)
+          flat_vertices_list.append(flat_vertex)
+          flat_vertices_indices[vertex_name] = flat_vertex_index
+          triangle.append(flat_vertex_index)
+        flat_triangles.append(triangle)
+
+  return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
+
+
+def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
+  """Write vertices and faces to obj."""
+  if transpose:
+    vertices = vertices[:, [1, 2, 0]]
+  vertices *= scale
+  if faces is not None:
+    if min(min(faces)) == 0:
+      f_add = 1
+    else:
+      f_add = 0
+  with open(file_path, 'w') as f:
+    for v in vertices:
+      f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
+    for face in faces:
+      line = 'f'
+      for i in face:
+        line += ' {}'.format(i + f_add)
+      line += '\n'
+      f.write(line)
+
+
+def quantize_verts(verts, n_bits=8):
+  """Convert vertices in [-0.5, 0.5] to values in [0, 2**n_bits - 1]."""
+  min_range = -0.5
+  max_range = 0.5
+  range_quantize = 2**n_bits - 1
+  verts_quantize = (verts - min_range) * range_quantize / (
+      max_range - min_range)
+  return verts_quantize.astype('int32')
+
+
+def dequantize_verts(verts, n_bits=8, add_noise=False):
+  """Convert quantized vertices to floats."""
+  min_range = -0.5
+  max_range = 0.5
+  range_quantize = 2**n_bits - 1
+  verts = verts.astype('float32')
+  verts = verts * (max_range - min_range) / range_quantize + min_range
+  if add_noise:
+    verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
+  return verts
+
+
+def face_to_cycles(face):
+  """Find cycles in face."""
+  g = nx.Graph()
+  for v in range(len(face) - 1):
+    g.add_edge(face[v], face[v + 1])
+  g.add_edge(face[-1], face[0])
+  return list(nx.cycle_basis(g))
+
+
+def flatten_faces(faces):
+  """Converts from list of faces to flat face array with stopping indices."""
+  if not faces:
+    return np.array([0])
+  else:
+    l = [f + [-1] for f in faces[:-1]]
+    l += [faces[-1] + [-2]]
+    return np.array([item for sublist in l for item in sublist]) + 2  # pylint: disable=g-complex-comprehension
+
+
+def unflatten_faces(flat_faces):
+  """Converts from flat face sequence to a list of separate faces."""
+  def group(seq):
+    g = []
+    for el in seq:
+      if el == 0 or el == -1:
+        yield g
+        g = []
+      else:
+        g.append(el - 1)
+    yield g
+  outputs = list(group(flat_faces - 1))[:-1]
+  # Remove empty faces
+  return [o for o in outputs if len(o) > 2]
+
+
+def center_vertices(vertices):
+  """Translate the vertices so that bounding box is centered at zero."""
+  vert_min = vertices.min(axis=0)
+  vert_max = vertices.max(axis=0)
+  vert_center = 0.5 * (vert_min + vert_max)
+  return vertices - vert_center
+
+
+def normalize_vertices_scale(vertices):
+  """Scale the vertices so that the long diagonal of the bounding box is one."""
+  vert_min = vertices.min(axis=0)
+  vert_max = vertices.max(axis=0)
+  extents = vert_max - vert_min
+  scale = np.sqrt(np.sum(extents**2))
+  return vertices / scale
+
+
+def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8):
+  """Quantize vertices, remove resulting duplicates and reindex faces."""
+  vertices = quantize_verts(vertices, quantization_bits)
+  vertices, inv = np.unique(vertices, axis=0, return_inverse=True)
+
+  # Sort vertices by z then y then x.
+  sort_inds = np.lexsort(vertices.T)
+  vertices = vertices[sort_inds]
+
+  # Re-index faces and tris to re-ordered vertices.
+  faces = [np.argsort(sort_inds)[inv[f]] for f in faces]
+  if tris is not None:
+    tris = np.array([np.argsort(sort_inds)[inv[t]] for t in tris])
+
+  # Merging duplicate vertices and re-indexing the faces causes some faces to
+  # contain loops (e.g. [2, 3, 5, 2, 4]). Split these faces into distinct
+  # sub-faces.
+  sub_faces = []
+  for f in faces:
+    cliques = face_to_cycles(f)
+    for c in cliques:
+      c_length = len(c)
+      # Only append faces with more than two verts.
+      if c_length > 2:
+        d = np.argmin(c)
+        # Cyclically permute faces so that the first index is the smallest.
+        sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
+  faces = sub_faces
+  if tris is not None:
+    tris = np.array([v for v in tris if len(set(v)) == len(v)])
+
+  # Sort faces by lowest vertex indices. If two faces have the same lowest
+  # index then sort by next lowest and so on.
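+  # (Illustration: with faces [[4, 0, 1], [2, 3, 5]] the sort keys below are
+  # (0, 1, 4) and (2, 3, 5), so [4, 0, 1] comes first.)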
+ faces.sort(key=lambda f: tuple(sorted(f))) + if tris is not None: + tris = tris.tolist() + tris.sort(key=lambda f: tuple(sorted(f))) + tris = np.array(tris) + + # After removing degenerate faces some vertices are now unreferenced. + # Remove these. + num_verts = vertices.shape[0] + vert_connected = np.equal( + np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1) + vertices = vertices[vert_connected] + + # Re-index faces and tris to re-ordered vertices. + vert_indices = ( + np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int'))) + faces = [vert_indices[f].tolist() for f in faces] + if tris is not None: + tris = np.array([vert_indices[t].tolist() for t in tris]) + + return vertices, faces, tris + + +def load_process_mesh(mesh_obj_path, quantization_bits=8): + """Load obj file and process.""" + # Load mesh + vertices, faces = read_obj(mesh_obj_path) + + # Transpose so that z-axis is vertical. + vertices = vertices[:, [2, 0, 1]] + + # Translate the vertices so that bounding box is centered at zero. + vertices = center_vertices(vertices) + + # Scale the vertices so that the long diagonal of the bounding box is equal + # to one. + vertices = normalize_vertices_scale(vertices) + + # Quantize and sort vertices, remove resulting duplicates, sort and reindex + # faces. + vertices, faces, _ = quantize_process_mesh( + vertices, faces, quantization_bits=quantization_bits) + + # Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens. + faces = flatten_faces(faces) + + # Discard degenerate meshes without faces. + return { + 'vertices': vertices, + 'faces': faces, + } + + +def plot_meshes(mesh_list, + ax_lims=0.3, + fig_size=4, + el=30, + rot_start=120, + vert_size=10, + vert_alpha=0.75, + n_cols=4): + """Plots mesh data using matplotlib.""" + + n_plot = len(mesh_list) + n_cols = np.minimum(n_plot, n_cols) + n_rows = np.ceil(n_plot / n_cols).astype('int') + fig = plt.figure(figsize=(fig_size * n_cols, fig_size * n_rows)) + for p_inc, mesh in enumerate(mesh_list): + + for key in [ + 'vertices', 'faces', 'vertices_conditional', 'pointcloud', 'class_name' + ]: + if key not in list(mesh.keys()): + mesh[key] = None + + ax = fig.add_subplot(n_rows, n_cols, p_inc + 1, projection='3d') + + if mesh['faces'] is not None: + if mesh['vertices_conditional'] is not None: + face_verts = np.concatenate( + [mesh['vertices_conditional'], mesh['vertices']], axis=0) + else: + face_verts = mesh['vertices'] + collection = [] + for f in mesh['faces']: + collection.append(face_verts[f]) + plt_mesh = Poly3DCollection(collection) + plt_mesh.set_edgecolor((0., 0., 0., 0.3)) + plt_mesh.set_facecolor((1, 0, 0, 0.2)) + ax.add_collection3d(plt_mesh) + + if mesh['vertices'] is not None: + ax.scatter3D( + mesh['vertices'][:, 0], + mesh['vertices'][:, 1], + mesh['vertices'][:, 2], + lw=0., + s=vert_size, + c='g', + alpha=vert_alpha) + + if mesh['vertices_conditional'] is not None: + ax.scatter3D( + mesh['vertices_conditional'][:, 0], + mesh['vertices_conditional'][:, 1], + mesh['vertices_conditional'][:, 2], + lw=0., + s=vert_size, + c='b', + alpha=vert_alpha) + + if mesh['pointcloud'] is not None: + ax.scatter3D( + mesh['pointcloud'][:, 0], + mesh['pointcloud'][:, 1], + mesh['pointcloud'][:, 2], + lw=0., + s=2.5 * vert_size, + c='b', + alpha=1.) + + ax.set_xlim(-ax_lims, ax_lims) + ax.set_ylim(-ax_lims, ax_lims) + ax.set_zlim(-ax_lims, ax_lims) + + ax.view_init(el, rot_start) + + display_string = '' + if mesh['faces'] is not None: + display_string += 'Num. 
faces: {}\n'.format(len(collection)) + if mesh['vertices'] is not None: + num_verts = mesh['vertices'].shape[0] + if mesh['vertices_conditional'] is not None: + num_verts += mesh['vertices_conditional'].shape[0] + display_string += 'Num. verts: {}\n'.format(num_verts) + if mesh['class_name'] is not None: + display_string += 'Synset: {}'.format(mesh['class_name']) + if mesh['pointcloud'] is not None: + display_string += 'Num. pointcloud: {}\n'.format( + mesh['pointcloud'].shape[0]) + ax.text2D(0.05, 0.8, display_string, transform=ax.transAxes) + plt.subplots_adjust( + left=0., right=1., bottom=0., top=1., wspace=0.025, hspace=0.025) + plt.show() diff --git a/polygen/media/example_samples.png b/polygen/media/example_samples.png new file mode 100644 index 0000000..4fc6778 Binary files /dev/null and b/polygen/media/example_samples.png differ diff --git a/polygen/meshes/cone.obj b/polygen/meshes/cone.obj new file mode 100644 index 0000000..22efc7c --- /dev/null +++ b/polygen/meshes/cone.obj @@ -0,0 +1,66 @@ +v 0.000000 -1.000000 -1.000000 +v 0.195090 -1.000000 -0.980785 +v 0.382683 -1.000000 -0.923880 +v 0.555570 -1.000000 -0.831470 +v 0.707107 -1.000000 -0.707107 +v 0.831470 -1.000000 -0.555570 +v 0.923880 -1.000000 -0.382683 +v 0.980785 -1.000000 -0.195090 +v 1.000000 -1.000000 -0.000000 +v 0.980785 -1.000000 0.195090 +v 0.923880 -1.000000 0.382683 +v 0.831470 -1.000000 0.555570 +v 0.707107 -1.000000 0.707107 +v 0.555570 -1.000000 0.831470 +v 0.382683 -1.000000 0.923880 +v 0.195090 -1.000000 0.980785 +v -0.000000 -1.000000 1.000000 +v -0.195091 -1.000000 0.980785 +v -0.382684 -1.000000 0.923879 +v -0.555571 -1.000000 0.831469 +v -0.707107 -1.000000 0.707106 +v -0.831470 -1.000000 0.555570 +v -0.923880 -1.000000 0.382683 +v 0.000000 1.000000 0.000000 +v -0.980785 -1.000000 0.195089 +v -1.000000 -1.000000 -0.000001 +v -0.980785 -1.000000 -0.195091 +v -0.923879 -1.000000 -0.382684 +v -0.831469 -1.000000 -0.555571 +v -0.707106 -1.000000 -0.707108 +v -0.555569 -1.000000 -0.831470 +v -0.382682 -1.000000 -0.923880 +v -0.195089 -1.000000 -0.980786 +f 1 24 2 +f 2 24 3 +f 3 24 4 +f 4 24 5 +f 5 24 6 +f 6 24 7 +f 7 24 8 +f 8 24 9 +f 9 24 10 +f 10 24 11 +f 11 24 12 +f 12 24 13 +f 13 24 14 +f 14 24 15 +f 15 24 16 +f 16 24 17 +f 17 24 18 +f 18 24 19 +f 19 24 20 +f 20 24 21 +f 21 24 22 +f 22 24 23 +f 23 24 25 +f 25 24 26 +f 26 24 27 +f 27 24 28 +f 28 24 29 +f 29 24 30 +f 30 24 31 +f 31 24 32 +f 32 24 33 +f 33 24 1 +f 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 25 26 27 28 29 30 31 32 33 diff --git a/polygen/meshes/cube.obj b/polygen/meshes/cube.obj new file mode 100644 index 0000000..9855ecd --- /dev/null +++ b/polygen/meshes/cube.obj @@ -0,0 +1,14 @@ +v -1.000000 -1.000000 1.000000 +v -1.000000 1.000000 1.000000 +v -1.000000 -1.000000 -1.000000 +v -1.000000 1.000000 -1.000000 +v 1.000000 -1.000000 1.000000 +v 1.000000 1.000000 1.000000 +v 1.000000 -1.000000 -1.000000 +v 1.000000 1.000000 -1.000000 +f 1 2 4 3 +f 3 4 8 7 +f 7 8 6 5 +f 5 6 2 1 +f 3 7 5 1 +f 8 4 2 6 diff --git a/polygen/meshes/cylinder.obj b/polygen/meshes/cylinder.obj new file mode 100644 index 0000000..4fad316 --- /dev/null +++ b/polygen/meshes/cylinder.obj @@ -0,0 +1,98 @@ +v 0.000000 -1.000000 -1.000000 +v 0.000000 1.000000 -1.000000 +v 0.195090 -1.000000 -0.980785 +v 0.195090 1.000000 -0.980785 +v 0.382683 -1.000000 -0.923880 +v 0.382683 1.000000 -0.923880 +v 0.555570 -1.000000 -0.831470 +v 0.555570 1.000000 -0.831470 +v 0.707107 -1.000000 -0.707107 +v 0.707107 1.000000 -0.707107 +v 0.831470 -1.000000 -0.555570 +v 0.831470 
1.000000 -0.555570 +v 0.923880 -1.000000 -0.382683 +v 0.923880 1.000000 -0.382683 +v 0.980785 -1.000000 -0.195090 +v 0.980785 1.000000 -0.195090 +v 1.000000 -1.000000 -0.000000 +v 1.000000 1.000000 -0.000000 +v 0.980785 -1.000000 0.195090 +v 0.980785 1.000000 0.195090 +v 0.923880 -1.000000 0.382683 +v 0.923880 1.000000 0.382683 +v 0.831470 -1.000000 0.555570 +v 0.831470 1.000000 0.555570 +v 0.707107 -1.000000 0.707107 +v 0.707107 1.000000 0.707107 +v 0.555570 -1.000000 0.831470 +v 0.555570 1.000000 0.831470 +v 0.382683 -1.000000 0.923880 +v 0.382683 1.000000 0.923880 +v 0.195090 -1.000000 0.980785 +v 0.195090 1.000000 0.980785 +v -0.000000 -1.000000 1.000000 +v -0.000000 1.000000 1.000000 +v -0.195091 -1.000000 0.980785 +v -0.195091 1.000000 0.980785 +v -0.382684 -1.000000 0.923879 +v -0.382684 1.000000 0.923879 +v -0.555571 -1.000000 0.831469 +v -0.555571 1.000000 0.831469 +v -0.707107 -1.000000 0.707106 +v -0.707107 1.000000 0.707106 +v -0.831470 -1.000000 0.555570 +v -0.831470 1.000000 0.555570 +v -0.923880 -1.000000 0.382683 +v -0.923880 1.000000 0.382683 +v -0.980785 -1.000000 0.195089 +v -0.980785 1.000000 0.195089 +v -1.000000 -1.000000 -0.000001 +v -1.000000 1.000000 -0.000001 +v -0.980785 -1.000000 -0.195091 +v -0.980785 1.000000 -0.195091 +v -0.923879 -1.000000 -0.382684 +v -0.923879 1.000000 -0.382684 +v -0.831469 -1.000000 -0.555571 +v -0.831469 1.000000 -0.555571 +v -0.707106 -1.000000 -0.707108 +v -0.707106 1.000000 -0.707108 +v -0.555569 -1.000000 -0.831470 +v -0.555569 1.000000 -0.831470 +v -0.382682 -1.000000 -0.923880 +v -0.382682 1.000000 -0.923880 +v -0.195089 -1.000000 -0.980786 +v -0.195089 1.000000 -0.980786 +f 1 2 4 3 +f 3 4 6 5 +f 5 6 8 7 +f 7 8 10 9 +f 9 10 12 11 +f 11 12 14 13 +f 13 14 16 15 +f 15 16 18 17 +f 17 18 20 19 +f 19 20 22 21 +f 21 22 24 23 +f 23 24 26 25 +f 25 26 28 27 +f 27 28 30 29 +f 29 30 32 31 +f 31 32 34 33 +f 33 34 36 35 +f 35 36 38 37 +f 37 38 40 39 +f 39 40 42 41 +f 41 42 44 43 +f 43 44 46 45 +f 45 46 48 47 +f 47 48 50 49 +f 49 50 52 51 +f 51 52 54 53 +f 53 54 56 55 +f 55 56 58 57 +f 57 58 60 59 +f 59 60 62 61 +f 4 2 64 62 60 58 56 54 52 50 48 46 44 42 40 38 36 34 32 30 28 26 24 22 20 18 16 14 12 10 8 6 +f 61 62 64 63 +f 63 64 2 1 +f 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39 41 43 45 47 49 51 53 55 57 59 61 63 diff --git a/polygen/meshes/icosphere.obj b/polygen/meshes/icosphere.obj new file mode 100644 index 0000000..1a1506b --- /dev/null +++ b/polygen/meshes/icosphere.obj @@ -0,0 +1,122 @@ +v 0.000000 -1.000000 0.000000 +v 0.723607 -0.447220 0.525725 +v -0.276388 -0.447220 0.850649 +v -0.894426 -0.447216 0.000000 +v -0.276388 -0.447220 -0.850649 +v 0.723607 -0.447220 -0.525725 +v 0.276388 0.447220 0.850649 +v -0.723607 0.447220 0.525725 +v -0.723607 0.447220 -0.525725 +v 0.276388 0.447220 -0.850649 +v 0.894426 0.447216 0.000000 +v 0.000000 1.000000 0.000000 +v -0.162456 -0.850654 0.499995 +v 0.425323 -0.850654 0.309011 +v 0.262869 -0.525738 0.809012 +v 0.850648 -0.525736 0.000000 +v 0.425323 -0.850654 -0.309011 +v -0.525730 -0.850652 0.000000 +v -0.688189 -0.525736 0.499997 +v -0.162456 -0.850654 -0.499995 +v -0.688189 -0.525736 -0.499997 +v 0.262869 -0.525738 -0.809012 +v 0.951058 0.000000 0.309013 +v 0.951058 0.000000 -0.309013 +v 0.000000 0.000000 1.000000 +v 0.587786 0.000000 0.809017 +v -0.951058 0.000000 0.309013 +v -0.587786 0.000000 0.809017 +v -0.587786 0.000000 -0.809017 +v -0.951058 0.000000 -0.309013 +v 0.587786 0.000000 -0.809017 +v 0.000000 0.000000 -1.000000 +v 0.688189 0.525736 0.499997 +v -0.262869 0.525738 
0.809012 +v -0.850648 0.525736 0.000000 +v -0.262869 0.525738 -0.809012 +v 0.688189 0.525736 -0.499997 +v 0.162456 0.850654 0.499995 +v 0.525730 0.850652 0.000000 +v -0.425323 0.850654 0.309011 +v -0.425323 0.850654 -0.309011 +v 0.162456 0.850654 -0.499995 +f 1 14 13 +f 2 14 16 +f 1 13 18 +f 1 18 20 +f 1 20 17 +f 2 16 23 +f 3 15 25 +f 4 19 27 +f 5 21 29 +f 6 22 31 +f 2 23 26 +f 3 25 28 +f 4 27 30 +f 5 29 32 +f 6 31 24 +f 7 33 38 +f 8 34 40 +f 9 35 41 +f 10 36 42 +f 11 37 39 +f 39 42 12 +f 39 37 42 +f 37 10 42 +f 42 41 12 +f 42 36 41 +f 36 9 41 +f 41 40 12 +f 41 35 40 +f 35 8 40 +f 40 38 12 +f 40 34 38 +f 34 7 38 +f 38 39 12 +f 38 33 39 +f 33 11 39 +f 24 37 11 +f 24 31 37 +f 31 10 37 +f 32 36 10 +f 32 29 36 +f 29 9 36 +f 30 35 9 +f 30 27 35 +f 27 8 35 +f 28 34 8 +f 28 25 34 +f 25 7 34 +f 26 33 7 +f 26 23 33 +f 23 11 33 +f 31 32 10 +f 31 22 32 +f 22 5 32 +f 29 30 9 +f 29 21 30 +f 21 4 30 +f 27 28 8 +f 27 19 28 +f 19 3 28 +f 25 26 7 +f 25 15 26 +f 15 2 26 +f 23 24 11 +f 23 16 24 +f 16 6 24 +f 17 22 6 +f 17 20 22 +f 20 5 22 +f 20 21 5 +f 20 18 21 +f 18 4 21 +f 18 19 4 +f 18 13 19 +f 13 3 19 +f 16 17 6 +f 16 14 17 +f 14 1 17 +f 13 15 3 +f 13 14 15 +f 14 2 15 diff --git a/polygen/model_test.py b/polygen/model_test.py new file mode 100644 index 0000000..64358fb --- /dev/null +++ b/polygen/model_test.py @@ -0,0 +1,159 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Tests for the PolyGen open-source version."""
+from modules import FaceModel
+from modules import VertexModel
+import numpy as np
+import tensorflow.compat.v1 as tf
+
+_BATCH_SIZE = 4
+_TRANSFORMER_CONFIG = {
+    'num_layers': 2,
+    'hidden_size': 64,
+    'fc_size': 256
+}
+_CLASS_CONDITIONAL = True
+_NUM_CLASSES = 4
+_NUM_INPUT_VERTS = 50
+_NUM_PAD_VERTS = 10
+_NUM_INPUT_FACE_INDICES = 200
+_QUANTIZATION_BITS = 8
+_VERTEX_MODEL_USE_DISCRETE_EMBEDDINGS = True
+_FACE_MODEL_DECODER_CROSS_ATTENTION = True
+_FACE_MODEL_DISCRETE_EMBEDDINGS = True
+_MAX_SAMPLE_LENGTH_VERTS = 10
+_MAX_SAMPLE_LENGTH_FACES = 10
+
+
+def _get_vertex_model_batch():
+  """Returns batch with placeholders for vertex model inputs."""
+  return {
+      'class_label': tf.range(_BATCH_SIZE),
+      'vertices_flat': tf.placeholder(
+          dtype=tf.int32, shape=[_BATCH_SIZE, None]),
+  }
+
+
+def _get_face_model_batch():
+  """Returns batch with placeholders for face model inputs."""
+  return {
+      'vertices': tf.placeholder(
+          dtype=tf.float32, shape=[_BATCH_SIZE, None, 3]),
+      'vertices_mask': tf.placeholder(
+          dtype=tf.float32, shape=[_BATCH_SIZE, None]),
+      'faces': tf.placeholder(
+          dtype=tf.int32, shape=[_BATCH_SIZE, None]),
+  }
+
+
+class VertexModelTest(tf.test.TestCase):
+
+  def setUp(self):
+    """Defines a vertex model."""
+    super(VertexModelTest, self).setUp()
+    self.model = VertexModel(
+        decoder_config=_TRANSFORMER_CONFIG,
+        class_conditional=_CLASS_CONDITIONAL,
+        num_classes=_NUM_CLASSES,
+        max_num_input_verts=_NUM_INPUT_VERTS,
+        quantization_bits=_QUANTIZATION_BITS,
+        use_discrete_embeddings=_VERTEX_MODEL_USE_DISCRETE_EMBEDDINGS)
+
+  def test_model_runs(self):
+    """Tests if the model runs without crashing."""
+    batch = _get_vertex_model_batch()
+    pred_dist = self.model(batch, is_training=False)
+    logits = pred_dist.logits
+    with self.session() as sess:
+      sess.run(tf.global_variables_initializer())
+      vertices_flat = np.random.randint(
+          2**_QUANTIZATION_BITS + 1,
+          size=[_BATCH_SIZE, _NUM_INPUT_VERTS * 3 + 1])
+      sess.run(logits, {batch['vertices_flat']: vertices_flat})
+
+  def test_sample_outputs_range(self):
+    """Tests if the model produces samples in the correct range."""
+    context = {'class_label': tf.zeros((_BATCH_SIZE,), dtype=tf.int32)}
+    sample_dict = self.model.sample(
+        _BATCH_SIZE, max_sample_length=_MAX_SAMPLE_LENGTH_VERTS,
+        context=context)
+    with self.session() as sess:
+      sess.run(tf.global_variables_initializer())
+      sample_dict_np = sess.run(sample_dict)
+      in_range = np.logical_and(
+          0 <= sample_dict_np['vertices'],
+          sample_dict_np['vertices'] <= 2**_QUANTIZATION_BITS).all()
+      self.assertTrue(in_range)
+
+
+class FaceModelTest(tf.test.TestCase):
+
+  def setUp(self):
+    """Defines a face model."""
+    super(FaceModelTest, self).setUp()
+    self.model = FaceModel(
+        encoder_config=_TRANSFORMER_CONFIG,
+        decoder_config=_TRANSFORMER_CONFIG,
+        class_conditional=False,
+        max_seq_length=_NUM_INPUT_FACE_INDICES,
+        decoder_cross_attention=_FACE_MODEL_DECODER_CROSS_ATTENTION,
+        use_discrete_vertex_embeddings=_FACE_MODEL_DISCRETE_EMBEDDINGS,
+        quantization_bits=_QUANTIZATION_BITS)
+
+  def test_model_runs(self):
+    """Tests if the model runs without crashing."""
+    batch = _get_face_model_batch()
+    pred_dist = self.model(batch, is_training=False)
+    logits = pred_dist.logits
+    with self.session() as sess:
+      sess.run(tf.global_variables_initializer())
+      vertices = np.random.rand(_BATCH_SIZE, _NUM_INPUT_VERTS, 3) - 0.5
+      vertices_mask = np.ones([_BATCH_SIZE, _NUM_INPUT_VERTS])
+      faces = np.random.randint(
+          _NUM_INPUT_VERTS + 2, size=[_BATCH_SIZE,
_NUM_INPUT_FACE_INDICES]) + sess.run( + logits, + {batch['vertices']: vertices, + batch['vertices_mask']: vertices_mask, + batch['faces']: faces} + ) + + def test_sample_outputs_range(self): + """Tests if the model produces samples in the correct range.""" + context = _get_face_model_batch() + del context['faces'] + sample_dict = self.model.sample( + context, max_sample_length=_MAX_SAMPLE_LENGTH_FACES) + with self.session() as sess: + sess.run(tf.global_variables_initializer()) + # Pad the vertices in order to test that the face model only outputs + # vertex indices in the unpadded range + vertices = np.pad( + np.random.rand(_BATCH_SIZE, _NUM_INPUT_VERTS, 3) - 0.5, + [[0, 0], [0, _NUM_PAD_VERTS], [0, 0]], mode='constant') + vertices_mask = np.pad( + np.ones([_BATCH_SIZE, _NUM_INPUT_VERTS]), + [[0, 0], [0, _NUM_PAD_VERTS]], mode='constant') + sample_dict_np = sess.run( + sample_dict, + {context['vertices']: vertices, + context['vertices_mask']: vertices_mask}) + in_range = np.logical_and( + 0 <= sample_dict_np['faces'], + sample_dict_np['faces'] <= _NUM_INPUT_VERTS + 1).all() + self.assertTrue(in_range) + +if __name__ == '__main__': + tf.test.main() diff --git a/polygen/modules.py b/polygen/modules.py new file mode 100644 index 0000000..8a848a3 --- /dev/null +++ b/polygen/modules.py @@ -0,0 +1,1506 @@ +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Modules and networks for mesh generation."""
+import sonnet as snt
+from tensor2tensor.layers import common_attention
+from tensor2tensor.layers import common_layers
+import tensorflow.compat.v1 as tf
+from tensorflow.python.framework import function
+import tensorflow_probability as tfp
+
+tfd = tfp.distributions
+tfb = tfp.bijectors
+
+
+def dequantize_verts(verts, n_bits, add_noise=False):
+  """Dequantizes integer vertices to floats."""
+  min_range = -0.5
+  max_range = 0.5
+  range_quantize = 2**n_bits - 1
+  verts = tf.cast(verts, tf.float32)
+  verts = verts * (max_range - min_range) / range_quantize + min_range
+  if add_noise:
+    verts += tf.random_uniform(tf.shape(verts)) * (1 / float(range_quantize))
+  return verts
+
+
+def quantize_verts(verts, n_bits):
+  """Quantizes vertices and outputs integers with specified n_bits."""
+  min_range = -0.5
+  max_range = 0.5
+  range_quantize = 2**n_bits - 1
+  verts_quantize = (
+      (verts - min_range) * range_quantize / (max_range - min_range))
+  return tf.cast(verts_quantize, tf.int32)
+
+
+def top_k_logits(logits, k):
+  """Masks logits such that logits not in top-k are small."""
+  if k == 0:
+    return logits
+  else:
+    values, _ = tf.math.top_k(logits, k=k)
+    k_largest = tf.reduce_min(values)
+    logits = tf.where(tf.less_equal(logits, k_largest),
+                      tf.ones_like(logits)*-1e9, logits)
+    return logits
+
+
+def top_p_logits(logits, p):
+  """Masks logits using nucleus (top-p) sampling."""
+  if p == 1:
+    return logits
+  else:
+    logit_shape = tf.shape(logits)
+    seq, dim = logit_shape[1], logit_shape[2]
+    logits = tf.reshape(logits, [-1, dim])
+    sort_indices = tf.argsort(logits, axis=-1, direction='DESCENDING')
+    probs = tf.gather(tf.nn.softmax(logits), sort_indices, batch_dims=1)
+    cumprobs = tf.cumsum(probs, axis=-1, exclusive=True)
+    # The top-1 candidate is never masked, which ensures that at least one
+    # index is always selected.
+    sort_mask = tf.cast(tf.greater(cumprobs, p), logits.dtype)
+    batch_indices = tf.tile(
+        tf.expand_dims(tf.range(tf.shape(logits)[0]), axis=-1), [1, dim])
+    top_p_mask = tf.scatter_nd(
+        tf.stack([batch_indices, sort_indices], axis=-1), sort_mask,
+        tf.shape(logits))
+    logits -= top_p_mask * 1e9
+    return tf.reshape(logits, [-1, seq, dim])
+
+
+_function_cache = {}  # For multihead_self_attention_memory_efficient
+
+
+def multihead_self_attention_memory_efficient(x,
+                                              bias,
+                                              num_heads,
+                                              head_size=None,
+                                              cache=None,
+                                              epsilon=1e-6,
+                                              forget=True,
+                                              test_vars=None,
+                                              name=None):
+  """Memory-efficient multihead scaled-dot-product self-attention.
+
+  Based on the Tensor2Tensor version but adds optional caching.
+
+  Returns multihead-self-attention(layer_norm(x)).
+
+  Computes one attention head at a time to avoid exhausting memory.
+
+  If forget=True, then forget all forwards activations and recompute on
+  the backwards pass.
+
+  Args:
+    x: a Tensor with shape [batch, length, input_size]
+    bias: an attention bias tensor broadcastable to [batch, 1, length, length]
+    num_heads: an integer
+    head_size: an optional integer - defaults to input_size/num_heads
+    cache: Optional dict containing tensors which are the results of previous
+      attentions, used for fast decoding. Expects the dict to contain two
+      keys ('k' and 'v'), for the initial call the values for these keys
+      should be empty Tensors of the appropriate shape.
+        'k' [batch_size, 0, key_channels]
+        'v' [batch_size, 0, value_channels]
+    epsilon: a float, for layer norm
+    forget: a boolean - forget forwards activations and recompute on backprop
+    test_vars: optional tuple of variables for testing purposes
+    name: an optional string
+
+  Returns:
+    A Tensor.
+  """
+  io_size = x.get_shape().as_list()[-1]
+  if head_size is None:
+    assert io_size % num_heads == 0
+    head_size = io_size // num_heads
+
+  def forward_internal(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
+    """Forward function."""
+    n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias)
+    wqkv_split = tf.unstack(wqkv, num=num_heads)
+    wo_split = tf.unstack(wo, num=num_heads)
+    y = 0
+    if cache is not None:
+      cache_k = []
+      cache_v = []
+    for h in range(num_heads):
+      with tf.control_dependencies([y] if h > 0 else []):
+        combined = tf.nn.conv1d(n, wqkv_split[h], 1, 'SAME')
+        q, k, v = tf.split(combined, 3, axis=2)
+        if cache is not None:
+          k = tf.concat([cache['k'][:, h], k], axis=1)
+          v = tf.concat([cache['v'][:, h], v], axis=1)
+          cache_k.append(k)
+          cache_v.append(v)
+        o = common_attention.scaled_dot_product_attention_simple(
+            q, k, v, attention_bias)
+        y += tf.nn.conv1d(o, wo_split[h], 1, 'SAME')
+    if cache is not None:
+      cache['k'] = tf.stack(cache_k, axis=1)
+      cache['v'] = tf.stack(cache_v, axis=1)
+    return y
+
+  key = (
+      'multihead_self_attention_memory_efficient %s %s' % (num_heads, epsilon))
+  if not forget:
+    forward_fn = forward_internal
+  elif key in _function_cache:
+    forward_fn = _function_cache[key]
+  else:
+
+    @function.Defun(compiled=True)
+    def grad_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias, dy):
+      """Custom gradient function."""
+      with tf.control_dependencies([dy]):
+        n = common_layers.layer_norm_compute(x, epsilon, norm_scale, norm_bias)
+        wqkv_split = tf.unstack(wqkv, num=num_heads)
+        wo_split = tf.unstack(wo, num=num_heads)
+        deps = []
+        dwqkvs = []
+        dwos = []
+        dn = 0
+        for h in range(num_heads):
+          with tf.control_dependencies(deps):
+            combined = tf.nn.conv1d(n, wqkv_split[h], 1, 'SAME')
+            q, k, v = tf.split(combined, 3, axis=2)
+            o = common_attention.scaled_dot_product_attention_simple(
+                q, k, v, attention_bias)
+            partial_y = tf.nn.conv1d(o, wo_split[h], 1, 'SAME')
+            pdn, dwqkvh, dwoh = tf.gradients(
+                ys=[partial_y],
+                xs=[n, wqkv_split[h], wo_split[h]],
+                grad_ys=[dy])
+            dn += pdn
+            dwqkvs.append(dwqkvh)
+            dwos.append(dwoh)
+            deps = [dn, dwqkvh, dwoh]
+        dwqkv = tf.stack(dwqkvs)
+        dwo = tf.stack(dwos)
+        with tf.control_dependencies(deps):
+          dx, dnorm_scale, dnorm_bias = tf.gradients(
+              ys=[n], xs=[x, norm_scale, norm_bias], grad_ys=[dn])
+        return (dx, dwqkv, dwo, tf.zeros_like(attention_bias), dnorm_scale,
+                dnorm_bias)
+
+    @function.Defun(
+        grad_func=grad_fn, compiled=True, separate_compiled_gradients=True)
+    def forward_fn(x, wqkv, wo, attention_bias, norm_scale, norm_bias):
+      return forward_internal(x, wqkv, wo, attention_bias, norm_scale,
+                              norm_bias)
+
+    _function_cache[key] = forward_fn
+
+  if bias is not None:
+    bias = tf.squeeze(bias, 1)
+  with tf.variable_scope(name, default_name='multihead_attention', values=[x]):
+    if test_vars is not None:
+      wqkv, wo, norm_scale, norm_bias = list(test_vars)
+    else:
+      wqkv = tf.get_variable(
+          'wqkv', [num_heads, 1, io_size, 3 * head_size],
+          initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
+      wo = tf.get_variable(
+          'wo', [num_heads, 1, head_size, io_size],
+          initializer=tf.random_normal_initializer(
+              stddev=(head_size * num_heads)**-0.5))
+      norm_scale, norm_bias =
common_layers.layer_norm_vars(io_size) + y = forward_fn(x, wqkv, wo, bias, norm_scale, norm_bias) + y.set_shape(x.get_shape()) # pytype: disable=attribute-error + return y + + +class TransformerEncoder(snt.AbstractModule): + """Transformer encoder. + + Sonnet Transformer encoder module as described in Vaswani et al. 2017. Uses + the Tensor2Tensor multihead_attention function for full self attention + (no masking). Layer norm is applied inside the residual path as in sparse + transformers (Child 2019). + + This module expects inputs to be already embedded, and does not add position + embeddings. + """ + + def __init__(self, + hidden_size=256, + fc_size=1024, + num_heads=4, + layer_norm=True, + num_layers=8, + dropout_rate=0.2, + re_zero=True, + memory_efficient=False, + name='transformer_encoder'): + """Initializes TransformerEncoder. + + Args: + hidden_size: Size of embedding vectors. + fc_size: Size of fully connected layer. + num_heads: Number of attention heads. + layer_norm: If True, apply layer normalization + num_layers: Number of Transformer blocks, where each block contains a + multi-head attention layer and a MLP. + dropout_rate: Dropout rate applied immediately after the ReLU in each + fully-connected layer. + re_zero: If True, alpha scale residuals with zero init. + memory_efficient: If True, recompute gradients for memory savings. + name: Name of variable scope + """ + super(TransformerEncoder, self).__init__(name=name) + self.hidden_size = hidden_size + self.num_heads = num_heads + self.layer_norm = layer_norm + self.fc_size = fc_size + self.num_layers = num_layers + self.dropout_rate = dropout_rate + self.re_zero = re_zero + self.memory_efficient = memory_efficient + + def _build(self, inputs, is_training=False): + """Passes inputs through Transformer encoder network. + + Args: + inputs: Tensor of shape [batch_size, sequence_length, embed_size]. Zero + embeddings are masked in self-attention. + is_training: If True, dropout is applied. + + Returns: + output: Tensor of shape [batch_size, sequence_length, embed_size]. + """ + if is_training: + dropout_rate = self.dropout_rate + else: + dropout_rate = 0. + + # Identify elements with all zeros as padding, and create bias to mask + # out padding elements in self attention. + encoder_padding = common_attention.embedding_to_padding(inputs) + encoder_self_attention_bias = ( + common_attention.attention_bias_ignore_padding(encoder_padding)) + + x = inputs + for layer_num in range(self.num_layers): + with tf.variable_scope('layer_{}'.format(layer_num)): + + # Multihead self-attention from Tensor2Tensor. + res = x + if self.memory_efficient: + res = multihead_self_attention_memory_efficient( + res, + bias=encoder_self_attention_bias, + num_heads=self.num_heads, + head_size=self.hidden_size // self.num_heads, + forget=True if is_training else False, + name='self_attention' + ) + else: + if self.layer_norm: + res = common_layers.layer_norm(res, name='self_attention') + res = common_attention.multihead_attention( + res, + memory_antecedent=None, + bias=encoder_self_attention_bias, + total_key_depth=self.hidden_size, + total_value_depth=self.hidden_size, + output_depth=self.hidden_size, + num_heads=self.num_heads, + dropout_rate=0., + make_image_summary=False, + name='self_attention') + if self.re_zero: + res *= tf.get_variable('self_attention/alpha', initializer=0.) 
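+        # (ReZero: the residual branch is scaled by a learned scalar 'alpha'
+        # initialized to zero, so each block starts out as the identity map.)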
+        if dropout_rate:
+          res = tf.nn.dropout(res, rate=dropout_rate)
+        x += res
+
+        # MLP
+        res = x
+        if self.layer_norm:
+          res = common_layers.layer_norm(res, name='fc')
+        res = tf.layers.dense(
+            res, self.fc_size, activation=tf.nn.relu, name='fc_1')
+        res = tf.layers.dense(res, self.hidden_size, name='fc_2')
+        if self.re_zero:
+          res *= tf.get_variable('fc/alpha', initializer=0.)
+        if dropout_rate:
+          res = tf.nn.dropout(res, rate=dropout_rate)
+        x += res
+
+    if self.layer_norm:
+      output = common_layers.layer_norm(x, name='output')
+    else:
+      output = x
+    return output
+
+
+class TransformerDecoder(snt.AbstractModule):
+  """Transformer decoder.
+
+  Sonnet Transformer decoder module as described in Vaswani et al. 2017. Uses
+  the Tensor2Tensor multihead_attention function for masked self attention, and
+  non-masked cross attention. Layer norm is applied inside the
+  residual path as in sparse transformers (Child 2019).
+
+  This module expects inputs to be already embedded, and does not
+  add position embeddings.
+  """
+
+  def __init__(self,
+               hidden_size=256,
+               fc_size=1024,
+               num_heads=4,
+               layer_norm=True,
+               num_layers=8,
+               dropout_rate=0.2,
+               re_zero=True,
+               memory_efficient=False,
+               name='transformer_decoder'):
+    """Initializes TransformerDecoder.
+
+    Args:
+      hidden_size: Size of embedding vectors.
+      fc_size: Size of fully connected layer.
+      num_heads: Number of attention heads.
+      layer_norm: If True, apply layer normalization. If memory_efficient
+        is True, then layer norm is always applied.
+      num_layers: Number of Transformer blocks, where each block contains a
+        multi-head attention layer and an MLP.
+      dropout_rate: Dropout rate applied immediately after the ReLU in each
+        fully-connected layer.
+      re_zero: If True, alpha scale residuals with zero init.
+      memory_efficient: If True, recompute gradients for memory savings.
+      name: Name of variable scope
+    """
+    super(TransformerDecoder, self).__init__(name=name)
+    self.hidden_size = hidden_size
+    self.num_heads = num_heads
+    self.layer_norm = layer_norm
+    self.fc_size = fc_size
+    self.num_layers = num_layers
+    self.dropout_rate = dropout_rate
+    self.re_zero = re_zero
+    self.memory_efficient = memory_efficient
+
+  def _build(self,
+             inputs,
+             sequential_context_embeddings=None,
+             is_training=False,
+             cache=None):
+    """Passes inputs through Transformer decoder network.
+
+    Args:
+      inputs: Tensor of shape [batch_size, sequence_length, embed_size]. Zero
+        embeddings are masked in self-attention.
+      sequential_context_embeddings: Optional tensor with global context
+        (e.g. image embeddings) of shape
+        [batch_size, context_seq_length, context_embed_size].
+      is_training: If True, dropout is applied.
+      cache: Optional dict containing tensors which are the results of previous
+        attentions, used for fast decoding. Expects the dict to contain two
+        keys ('k' and 'v'), for the initial call the values for these keys
+        should be empty Tensors of the appropriate shape.
+        'k' [batch_size, 0, key_channels]
+        'v' [batch_size, 0, value_channels]
+
+    Returns:
+      output: Tensor of shape [batch_size, sequence_length, embed_size].
+    """
+    if is_training:
+      dropout_rate = self.dropout_rate
+    else:
+      dropout_rate = 0.
+
+    # Create bias to mask future elements for causal self-attention.
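+    # (attention_bias_lower_triangle(n) returns a [1, 1, n, n] bias that is
+    # strongly negative above the diagonal, so position i cannot attend to
+    # positions j > i.)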
+ seq_length = tf.shape(inputs)[1] + decoder_self_attention_bias = common_attention.attention_bias_lower_triangle( + seq_length) + + # If using sequential_context, identify elements with all zeros as padding, + # and create bias to mask out padding elements in self attention. + if sequential_context_embeddings is not None: + encoder_padding = common_attention.embedding_to_padding( + sequential_context_embeddings) + encoder_decoder_attention_bias = ( + common_attention.attention_bias_ignore_padding(encoder_padding)) + + x = inputs + for layer_num in range(self.num_layers): + with tf.variable_scope('layer_{}'.format(layer_num)): + + # If using cached decoding, access cache for current layer, and create + # bias that enables un-masked attention into the cache + if cache is not None: + layer_cache = cache[layer_num] + layer_decoder_bias = tf.zeros([1, 1, 1, 1]) + # Otherwise use standard masked bias + else: + layer_cache = None + layer_decoder_bias = decoder_self_attention_bias + + # Multihead self-attention from Tensor2Tensor. + res = x + if self.memory_efficient: + res = multihead_self_attention_memory_efficient( + res, + bias=layer_decoder_bias, + cache=layer_cache, + num_heads=self.num_heads, + head_size=self.hidden_size // self.num_heads, + forget=True if is_training else False, + name='self_attention' + ) + else: + if self.layer_norm: + res = common_layers.layer_norm(res, name='self_attention') + res = common_attention.multihead_attention( + res, + memory_antecedent=None, + bias=layer_decoder_bias, + total_key_depth=self.hidden_size, + total_value_depth=self.hidden_size, + output_depth=self.hidden_size, + num_heads=self.num_heads, + cache=layer_cache, + dropout_rate=0., + make_image_summary=False, + name='self_attention') + if self.re_zero: + res *= tf.get_variable('self_attention/alpha', initializer=0.) + if dropout_rate: + res = tf.nn.dropout(res, rate=dropout_rate) + x += res + + # Optional cross attention into sequential context + if sequential_context_embeddings is not None: + res = x + if self.layer_norm: + res = common_layers.layer_norm(res, name='cross_attention') + res = common_attention.multihead_attention( + res, + memory_antecedent=sequential_context_embeddings, + bias=encoder_decoder_attention_bias, + total_key_depth=self.hidden_size, + total_value_depth=self.hidden_size, + output_depth=self.hidden_size, + num_heads=self.num_heads, + dropout_rate=0., + make_image_summary=False, + name='cross_attention') + if self.re_zero: + res *= tf.get_variable('cross_attention/alpha', initializer=0.) + if dropout_rate: + res = tf.nn.dropout(res, rate=dropout_rate) + x += res + + # FC layers + res = x + if self.layer_norm: + res = common_layers.layer_norm(res, name='fc') + res = tf.layers.dense( + res, self.fc_size, activation=tf.nn.relu, name='fc_1') + res = tf.layers.dense(res, self.hidden_size, name='fc_2') + if self.re_zero: + res *= tf.get_variable('fc/alpha', initializer=0.) 
+ if dropout_rate: + res = tf.nn.dropout(res, rate=dropout_rate) + x += res + + if self.layer_norm: + output = common_layers.layer_norm(x, name='output') + else: + output = x + return output + + def create_init_cache(self, batch_size): + """Creates empty cache dictionary for use in fast decoding.""" + + def compute_cache_shape_invariants(tensor): + """Helper function to get dynamic shapes for cache tensors.""" + shape_list = tensor.shape.as_list() + if len(shape_list) == 4: + return tf.TensorShape( + [shape_list[0], shape_list[1], None, shape_list[3]]) + elif len(shape_list) == 3: + return tf.TensorShape([shape_list[0], None, shape_list[2]]) + + # Build cache + k = common_attention.split_heads( + tf.zeros([batch_size, 0, self.hidden_size]), self.num_heads) + v = common_attention.split_heads( + tf.zeros([batch_size, 0, self.hidden_size]), self.num_heads) + cache = [{'k': k, 'v': v} for _ in range(self.num_layers)] + shape_invariants = tf.nest.map_structure( + compute_cache_shape_invariants, cache) + return cache, shape_invariants + + +def conv_residual_block(inputs, + output_channels=None, + downsample=False, + kernel_size=3, + re_zero=True, + dropout_rate=0., + name='conv_residual_block'): + """Convolutional block with residual connections for 2D or 3D inputs. + + Args: + inputs: Input tensor of shape [batch_size, height, width, channels] or + [batch_size, height, width, depth, channels]. + output_channels: Number of output channels. + downsample: If True, downsample by 1/2 in this block. + kernel_size: Spatial size of convolutional kernels. + re_zero: If True, alpha scale residuals with zero init. + dropout_rate: Dropout rate applied after second ReLU in residual path. + name: Name for variable scope. + + Returns: + outputs: Output tensor of shape [batch_size, height, width, output_channels] + or [batch_size, height, width, depth, output_channels]. + """ + with tf.variable_scope(name): + input_shape = inputs.get_shape().as_list() + num_dims = len(input_shape) - 2 + + if num_dims == 2: + conv = tf.layers.conv2d + elif num_dims == 3: + conv = tf.layers.conv3d + + input_channels = input_shape[-1] + if output_channels is None: + output_channels = input_channels + if downsample: + shortcut = conv( + inputs, + filters=output_channels, + strides=2, + kernel_size=kernel_size, + padding='same', + name='conv_shortcut') + else: + shortcut = inputs + + res = inputs + res = tf.nn.relu(res) + res = conv( + res, filters=input_channels, kernel_size=kernel_size, padding='same', + name='conv_1') + + res = tf.nn.relu(res) + if dropout_rate: + res = tf.nn.dropout(res, rate=dropout_rate) + if downsample: + out_strides = 2 + else: + out_strides = 1 + res = conv( + res, + filters=output_channels, + kernel_size=kernel_size, + padding='same', + strides=out_strides, + name='conv_2') + if re_zero: + res *= tf.get_variable('alpha', initializer=0.) + return shortcut + res + + +class ResNet(snt.AbstractModule): + """ResNet architecture for 2D image or 3D voxel inputs.""" + + def __init__(self, + num_dims, + hidden_sizes=(64, 256), + num_blocks=(2, 2), + dropout_rate=0.1, + re_zero=True, + name='res_net'): + """Initializes ResNet. + + Args: + num_dims: Number of spatial dimensions. 2 for images or 3 for voxels. + hidden_sizes: Sizes of hidden layers in resnet blocks. + num_blocks: Number of resnet blocks at each size. + dropout_rate: Dropout rate applied immediately after the ReLU in each + fully-connected layer. + re_zero: If True, alpha scale residuals with zero init. 
+      name: Name of variable scope
+    """
+    super(ResNet, self).__init__(name=name)
+    self.num_dims = num_dims
+    self.hidden_sizes = hidden_sizes
+    self.num_blocks = num_blocks
+    self.dropout_rate = dropout_rate
+    self.re_zero = re_zero
+
+  def _build(self, inputs, is_training=False):
+    """Passes inputs through resnet.
+
+    Args:
+      inputs: Tensor of shape [batch_size, height, width, channels] or
+        [batch_size, height, width, depth, channels].
+      is_training: If True, dropout is applied.
+
+    Returns:
+      output: Tensor of shape [batch_size, height, width, output_size] or
+        [batch_size, height, width, depth, output_size].
+    """
+    if is_training:
+      dropout_rate = self.dropout_rate
+    else:
+      dropout_rate = 0.
+
+    # Initial projection with large kernel as in original resnet architecture
+    if self.num_dims == 3:
+      conv = tf.layers.conv3d
+    elif self.num_dims == 2:
+      conv = tf.layers.conv2d
+    x = conv(
+        inputs,
+        filters=self.hidden_sizes[0],
+        kernel_size=7,
+        strides=2,
+        padding='same',
+        name='conv_input')
+
+    if self.num_dims == 2:
+      x = tf.layers.max_pooling2d(
+          x, strides=2, pool_size=3, padding='same', name='pool_input')
+
+    for d, (hidden_size,
+            blocks) in enumerate(zip(self.hidden_sizes, self.num_blocks)):
+
+      with tf.variable_scope('resolution_{}'.format(d)):
+
+        # Downsample at the start of each collection of blocks
+        x = conv_residual_block(
+            x,
+            downsample=False if d == 0 else True,
+            dropout_rate=dropout_rate,
+            output_channels=hidden_size,
+            re_zero=self.re_zero,
+            name='block_1_downsample')
+        for i in range(blocks - 1):
+          x = conv_residual_block(
+              x,
+              dropout_rate=dropout_rate,
+              output_channels=hidden_size,
+              re_zero=self.re_zero,
+              name='block_{}'.format(i + 2))
+    return x
+
+
+class VertexModel(snt.AbstractModule):
+  """Autoregressive generative model of quantized mesh vertices.
+
+  Operates on flattened vertex sequences with a stopping token:
+
+  [z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, x_n, STOP]
+
+  Input vertex coordinates are embedded and tagged with learned coordinate and
+  position indicators. A transformer decoder outputs logits for a quantized
+  vertex distribution.
+  """
+
+  def __init__(self,
+               decoder_config,
+               quantization_bits,
+               class_conditional=False,
+               num_classes=55,
+               max_num_input_verts=2500,
+               use_discrete_embeddings=True,
+               name='vertex_model'):
+    """Initializes VertexModel.
+
+    Args:
+      decoder_config: Dictionary with TransformerDecoder config.
+      quantization_bits: Number of quantization bits used in mesh
+        preprocessing.
+      class_conditional: If True, then condition on learned class embeddings.
+      num_classes: Number of classes to condition on.
+      max_num_input_verts: Maximum number of vertices. Used for learned
+        position embeddings.
+      use_discrete_embeddings: If True, use discrete rather than continuous
+        vertex embeddings.
+ name: Name of variable scope + """ + super(VertexModel, self).__init__(name=name) + self.embedding_dim = decoder_config['hidden_size'] + self.class_conditional = class_conditional + self.num_classes = num_classes + self.max_num_input_verts = max_num_input_verts + self.quantization_bits = quantization_bits + self.use_discrete_embeddings = use_discrete_embeddings + + with self._enter_variable_scope(): + self.decoder = TransformerDecoder(**decoder_config) + + @snt.reuse_variables + def _embed_class_label(self, labels): + """Embeds class label with learned embedding matrix.""" + init_dict = {'embeddings': tf.glorot_uniform_initializer} + return snt.Embed( + vocab_size=self.num_classes, + embed_dim=self.embedding_dim, + initializers=init_dict, + densify_gradients=True, + name='class_label')(labels) + + @snt.reuse_variables + def _prepare_context(self, context, is_training=False): + """Prepare class label context.""" + if self.class_conditional: + global_context_embedding = self._embed_class_label(context['class_label']) + else: + global_context_embedding = None + return global_context_embedding, None + + @snt.reuse_variables + def _embed_inputs(self, vertices, global_context_embedding=None): + """Embeds flat vertices and adds position and coordinate information.""" + # Dequantize inputs and get shapes + input_shape = tf.shape(vertices) + batch_size, seq_length = input_shape[0], input_shape[1] + + # Coord indicators (x, y, z) + coord_embeddings = snt.Embed( + vocab_size=3, + embed_dim=self.embedding_dim, + initializers={'embeddings': tf.glorot_uniform_initializer}, + densify_gradients=True, + name='coord_embeddings')(tf.mod(tf.range(seq_length), 3)) + + # Position embeddings + pos_embeddings = snt.Embed( + vocab_size=self.max_num_input_verts, + embed_dim=self.embedding_dim, + initializers={'embeddings': tf.glorot_uniform_initializer}, + densify_gradients=True, + name='coord_embeddings')(tf.floordiv(tf.range(seq_length), 3)) + + # Discrete vertex value embeddings + if self.use_discrete_embeddings: + vert_embeddings = snt.Embed( + vocab_size=2**self.quantization_bits + 1, + embed_dim=self.embedding_dim, + initializers={'embeddings': tf.glorot_uniform_initializer}, + densify_gradients=True, + name='value_embeddings')(vertices) + # Continuous vertex value embeddings + else: + vert_embeddings = tf.layers.dense( + dequantize_verts(vertices[Ellipsis, None], self.quantization_bits), + self.embedding_dim, + use_bias=True, + name='value_embeddings') + + # Step zero embeddings + if global_context_embedding is None: + zero_embed = tf.get_variable( + 'embed_zero', shape=[1, 1, self.embedding_dim]) + zero_embed_tiled = tf.tile(zero_embed, [batch_size, 1, 1]) + else: + zero_embed_tiled = global_context_embedding[:, None] + + # Aggregate embeddings + embeddings = vert_embeddings + (coord_embeddings + pos_embeddings)[None] + embeddings = tf.concat([zero_embed_tiled, embeddings], axis=1) + + return embeddings + + @snt.reuse_variables + def _project_to_logits(self, inputs): + """Projects transformer outputs to logits for predictive distribution.""" + return tf.layers.dense( + inputs, + 2**self.quantization_bits + 1, # + 1 for stopping token + use_bias=True, + kernel_initializer=tf.zeros_initializer(), + name='project_to_logits') + + @snt.reuse_variables + def _create_dist(self, + vertices, + global_context_embedding=None, + sequential_context_embeddings=None, + temperature=1., + top_k=0, + top_p=1., + is_training=False, + cache=None): + """Outputs categorical dist for quantized vertex coordinates.""" + + # Embed 
inputs + decoder_inputs = self._embed_inputs(vertices, global_context_embedding) + if cache is not None: + decoder_inputs = decoder_inputs[:, -1:] + + # pass through decoder + outputs = self.decoder( + decoder_inputs, cache=cache, + sequential_context_embeddings=sequential_context_embeddings, + is_training=is_training) + + # Get logits and optionally process for sampling + logits = self._project_to_logits(outputs) + logits /= temperature + logits = top_k_logits(logits, top_k) + logits = top_p_logits(logits, top_p) + cat_dist = tfd.Categorical(logits=logits) + return cat_dist + + def _build(self, batch, is_training=False): + """Pass batch through vertex model and get log probabilities under model. + + Args: + batch: Dictionary containing: + 'vertices_flat': int32 vertex tensors of shape [batch_size, seq_length]. + is_training: If True, use dropout. + + Returns: + pred_dist: tfd.Categorical predictive distribution with batch shape + [batch_size, seq_length]. + """ + global_context, seq_context = self._prepare_context( + batch, is_training=is_training) + pred_dist = self._create_dist( + batch['vertices_flat'][:, :-1], # Last element not used for preds + global_context_embedding=global_context, + sequential_context_embeddings=seq_context, + is_training=is_training) + return pred_dist + + def sample(self, + num_samples, + context=None, + max_sample_length=None, + temperature=1., + top_k=0, + top_p=1., + recenter_verts=True, + only_return_complete=True): + """Autoregressive sampling with caching. + + Args: + num_samples: Number of samples to produce. + context: Dictionary of context, such as class labels. See _prepare_context + for details. + max_sample_length: Maximum length of sampled vertex sequences. Sequences + that do not complete are truncated. + temperature: Scalar softmax temperature > 0. + top_k: Number of tokens to keep for top-k sampling. + top_p: Proportion of probability mass to keep for top-p sampling. + recenter_verts: If True, center vertex samples around origin. This should + be used if model is trained using shift augmentations. + only_return_complete: If True, only return completed samples. Otherwise + return all samples along with completed indicator. + + Returns: + outputs: Output dictionary with fields: + 'completed': Boolean tensor of shape [num_samples]. If True then + corresponding sample completed within max_sample_length. + 'vertices': Tensor of samples with shape [num_samples, num_verts, 3]. + 'num_vertices': Tensor indicating number of vertices for each example + in padded vertex samples. + 'vertices_mask': Tensor of shape [num_samples, num_verts] that masks + corresponding invalid elements in 'vertices'. + """ + # Obtain context for decoder + global_context, seq_context = self._prepare_context( + context, is_training=False) + + # num_samples is the minimum value of num_samples and the batch size of + # context inputs (if present). 
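+    # This guarantees that each generated sample is paired with its own
+    # context element.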
+    if global_context is not None:
+      num_samples = tf.minimum(num_samples, tf.shape(global_context)[0])
+      global_context = global_context[:num_samples]
+      if seq_context is not None:
+        seq_context = seq_context[:num_samples]
+    elif seq_context is not None:
+      num_samples = tf.minimum(num_samples, tf.shape(seq_context)[0])
+      seq_context = seq_context[:num_samples]
+
+    def _loop_body(i, samples, cache):
+      """While-loop body for autoregression calculation."""
+      cat_dist = self._create_dist(
+          samples,
+          global_context_embedding=global_context,
+          sequential_context_embeddings=seq_context,
+          cache=cache,
+          temperature=temperature,
+          top_k=top_k,
+          top_p=top_p)
+      next_sample = cat_dist.sample()
+      samples = tf.concat([samples, next_sample], axis=1)
+      return i + 1, samples, cache
+
+    def _stopping_cond(i, samples, cache):
+      """Stopping condition for sampling while-loop."""
+      del i, cache  # Unused
+      return tf.reduce_any(tf.reduce_all(tf.not_equal(samples, 0), axis=-1))
+
+    # Initial values for loop variables
+    samples = tf.zeros([num_samples, 0], dtype=tf.int32)
+    max_sample_length = max_sample_length or self.max_num_input_verts
+    cache, cache_shape_invariants = self.decoder.create_init_cache(num_samples)
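+    # Sample until all sequences have produced a stopping token, up to a
+    # budget of three coordinate tokens per vertex plus one stopping token.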
+    _, v, _ = tf.while_loop(
+        cond=_stopping_cond,
+        body=_loop_body,
+        loop_vars=(0, samples, cache),
+        shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]),
+                          cache_shape_invariants),
+        maximum_iterations=max_sample_length * 3 + 1,
+        back_prop=False,
+        parallel_iterations=1)
+
+    # Check if samples completed. Samples are complete if the stopping token
+    # is produced.
+    completed = tf.reduce_any(tf.equal(v, 0), axis=-1)
+
+    # Get the number of vertices in the sample. This requires finding the
+    # index of the stopping token. For complete samples, use argmax to get the
+    # index of the first stopping (zero) token.
+    stop_index_completed = tf.argmax(
+        tf.cast(tf.equal(v, 0), tf.int32), axis=-1, output_type=tf.int32)
+    # For incomplete samples the stopping index is just the maximum index.
+    stop_index_incomplete = (
+        max_sample_length * 3 * tf.ones_like(stop_index_completed))
+    stop_index = tf.where(
+        completed, stop_index_completed, stop_index_incomplete)
+    num_vertices = tf.floordiv(stop_index, 3)
+
+    # Convert to 3D vertices by reshaping and re-ordering x -> y -> z
+    v = v[:, :(tf.reduce_max(num_vertices) * 3)] - 1
+    verts_dequantized = dequantize_verts(v, self.quantization_bits)
+    vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
+    vertices = tf.stack(
+        [vertices[Ellipsis, 2], vertices[Ellipsis, 1], vertices[Ellipsis, 0]], axis=-1)
+
+    # Pad samples to max sample length. This is required in order to
+    # concatenate samples across different replicator instances. Pad with
+    # stopping tokens for incomplete samples.
+    pad_size = max_sample_length - tf.shape(vertices)[1]
+    vertices = tf.pad(vertices, [[0, 0], [0, pad_size], [0, 0]])
+
+    # 3D Vertex mask
+    vertices_mask = tf.cast(
+        tf.range(max_sample_length)[None] < num_vertices[:, None], tf.float32)
+
+    if recenter_verts:
+      vert_max = tf.reduce_max(
+          vertices - 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
+          keepdims=True)
+      vert_min = tf.reduce_min(
+          vertices + 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
+          keepdims=True)
+      vert_centers = 0.5 * (vert_max + vert_min)
+      vertices -= vert_centers
+    vertices *= vertices_mask[Ellipsis, None]
+
+    if only_return_complete:
+      vertices = tf.boolean_mask(vertices, completed)
+      num_vertices = tf.boolean_mask(num_vertices, completed)
+      vertices_mask = tf.boolean_mask(vertices_mask, completed)
+      completed = tf.boolean_mask(completed, completed)
+
+    # Outputs
+    outputs = {
+        'completed': completed,
+        'vertices': vertices,
+        'num_vertices': num_vertices,
+        'vertices_mask': vertices_mask,
+    }
+    return outputs
+
+
+class ImageToVertexModel(VertexModel):
+  """Generative model of quantized mesh vertices with image conditioning.
+
+  Operates on flattened vertex sequences with a stopping token:
+
+  [z_0, y_0, x_0, z_1, y_1, x_1, ..., z_n, y_n, x_n, STOP]
+
+  Input vertex coordinates are embedded and tagged with learned coordinate and
+  position indicators. A transformer decoder outputs logits for a quantized
+  vertex distribution. Image inputs are encoded and used to condition the
+  vertex decoder.
+  """
+
+  def __init__(self,
+               res_net_config,
+               decoder_config,
+               quantization_bits,
+               use_discrete_embeddings=True,
+               max_num_input_verts=2500,
+               name='image_to_vertex_model'):
+    """Initializes ImageToVertexModel.
+
+    Args:
+      res_net_config: Dictionary with ResNet config.
+      decoder_config: Dictionary with TransformerDecoder config.
+      quantization_bits: Number of quantization bits used in mesh
+        preprocessing.
+      use_discrete_embeddings: If True, use discrete rather than continuous
+        vertex embeddings.
+      max_num_input_verts: Maximum number of vertices. Used for learned
+        position embeddings.
+      name: Name of variable scope.
+    """
+    super(ImageToVertexModel, self).__init__(
+        decoder_config=decoder_config,
+        quantization_bits=quantization_bits,
+        max_num_input_verts=max_num_input_verts,
+        use_discrete_embeddings=use_discrete_embeddings,
+        name=name)
+
+    with self._enter_variable_scope():
+      self.res_net = ResNet(num_dims=2, **res_net_config)
+
+  @snt.reuse_variables
+  def _prepare_context(self, context, is_training=False):
+
+    # Pass images through encoder
+    image_embeddings = self.res_net(
+        context['image'] - 0.5, is_training=is_training)
+
+    # Add 2D coordinate grid embedding
+    processed_image_resolution = tf.shape(image_embeddings)[1]
+    x = tf.linspace(-1., 1., processed_image_resolution)
+    image_coords = tf.stack(tf.meshgrid(x, x), axis=-1)
+    image_coord_embeddings = tf.layers.dense(
+        image_coords,
+        self.embedding_dim,
+        use_bias=True,
+        name='image_coord_embeddings')
+    image_embeddings += image_coord_embeddings[None]
+
+    # Reshape spatial grid to sequence
+    batch_size = tf.shape(image_embeddings)[0]
+    sequential_context_embedding = tf.reshape(
+        image_embeddings, [batch_size, -1, self.embedding_dim])
+
+    return None, sequential_context_embedding
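+
+
+# A minimal usage sketch for the image-conditional model (illustrative only:
+# `images` and `verts_flat` are hypothetical placeholders, and `res_net_config`
+# must be filled in with a valid ResNet configuration):
+#
+#   model = ImageToVertexModel(
+#       res_net_config=dict(...),
+#       decoder_config=dict(
+#           hidden_size=128, fc_size=512, num_layers=3, dropout_rate=0.),
+#       quantization_bits=8)
+#   pred_dist = model({'image': images, 'vertices_flat': verts_flat})
+#   loss = -tf.reduce_sum(pred_dist.log_prob(verts_flat))  # Mask in practice.
+#   samples = model.sample(4, context={'image': images})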
+ """ + + def __init__(self, + res_net_config, + decoder_config, + quantization_bits, + use_discrete_embeddings=True, + max_num_input_verts=2500, + name='voxel_to_vertex_model'): + """Initializes VoxelToVertexModel. + + Args: + res_net_config: Dictionary with ResNet config. + decoder_config: Dictionary with TransformerDecoder config. + quantization_bits: Integer number of bits used for vertex quantization. + use_discrete_embeddings: If True, use discrete rather than continuous + vertex embeddings. + max_num_input_verts: Maximum number of vertices. Used for learned position + embeddings. + name: Name of variable scope + """ + super(VoxelToVertexModel, self).__init__( + decoder_config=decoder_config, + quantization_bits=quantization_bits, + max_num_input_verts=max_num_input_verts, + use_discrete_embeddings=use_discrete_embeddings, + name=name) + + with self._enter_variable_scope(): + self.res_net = ResNet(num_dims=3, **res_net_config) + + @snt.reuse_variables + def _prepare_context(self, context, is_training=False): + + # Embed binary input voxels + voxel_embeddings = snt.Embed( + vocab_size=2, + embed_dim=self.pre_embed_dim, + initializers={'embeddings': tf.glorot_uniform_initializer}, + densify_gradients=True, + name='voxel_embeddings')(context['voxels']) + + # Pass embedded voxels through voxel encoder + voxel_embeddings = self.res_net( + voxel_embeddings, is_training=is_training) + + # Add 3D coordinate grid embedding + processed_voxel_resolution = tf.shape(voxel_embeddings)[1] + x = tf.linspace(-1., 1., processed_voxel_resolution) + voxel_coords = tf.stack(tf.meshgrid(x, x, x), axis=-1) + voxel_coord_embeddings = tf.layers.dense( + voxel_coords, + self.embedding_dim, + use_bias=True, + name='voxel_coord_embeddings') + voxel_embeddings += voxel_coord_embeddings[None] + + # Reshape spatial grid to sequence + batch_size = tf.shape(voxel_embeddings)[0] + sequential_context_embedding = tf.reshape( + voxel_embeddings, [batch_size, -1, self.embedding_dim]) + + return None, sequential_context_embedding + + +class FaceModel(snt.AbstractModule): + """Autoregressive generative model of n-gon meshes. + + Operates on sets of input vertices as well as flattened face sequences with + new face and stopping tokens: + + [f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP] + + Input vertices are encoded using a Transformer encoder. + + Input face sequences are embedded and tagged with learned position indicators, + as well as their corresponding vertex embeddings. A transformer decoder + outputs a pointer which is compared to each vertex embedding to obtain a + distribution over vertex indices. + """ + + def __init__(self, + encoder_config, + decoder_config, + class_conditional=True, + num_classes=55, + decoder_cross_attention=True, + use_discrete_vertex_embeddings=True, + quantization_bits=8, + max_seq_length=5000, + name='face_model'): + """Initializes FaceModel. + + Args: + encoder_config: Dictionary with TransformerEncoder config. + decoder_config: Dictionary with TransformerDecoder config. + class_conditional: If True, then condition on learned class embeddings. + num_classes: Number of classes to condition on. + decoder_cross_attention: If True, the use cross attention from decoder + querys into encoder outputs. + use_discrete_vertex_embeddings: If True, use discrete vertex embeddings. + quantization_bits: Number of quantization bits for discrete vertex + embeddings. + max_seq_length: Maximum face sequence length. Used for learned position + embeddings. 
+
+
+class FaceModel(snt.AbstractModule):
+  """Autoregressive generative model of n-gon meshes.
+
+  Operates on sets of input vertices as well as flattened face sequences with
+  new face and stopping tokens:
+
+  [f_0^0, f_0^1, f_0^2, NEW, f_1^0, f_1^1, ..., STOP]
+
+  Input vertices are encoded using a Transformer encoder.
+
+  Input face sequences are embedded and tagged with learned position
+  indicators, as well as their corresponding vertex embeddings. A transformer
+  decoder outputs a pointer which is compared to each vertex embedding to
+  obtain a distribution over vertex indices.
+  """
+
+  def __init__(self,
+               encoder_config,
+               decoder_config,
+               class_conditional=True,
+               num_classes=55,
+               decoder_cross_attention=True,
+               use_discrete_vertex_embeddings=True,
+               quantization_bits=8,
+               max_seq_length=5000,
+               name='face_model'):
+    """Initializes FaceModel.
+
+    Args:
+      encoder_config: Dictionary with TransformerEncoder config.
+      decoder_config: Dictionary with TransformerDecoder config.
+      class_conditional: If True, then condition on learned class embeddings.
+      num_classes: Number of classes to condition on.
+      decoder_cross_attention: If True, use cross attention from decoder
+        queries into encoder outputs.
+      use_discrete_vertex_embeddings: If True, use discrete vertex embeddings.
+      quantization_bits: Number of quantization bits for discrete vertex
+        embeddings.
+      max_seq_length: Maximum face sequence length. Used for learned position
+        embeddings.
+      name: Name of variable scope.
+    """
+    super(FaceModel, self).__init__(name=name)
+    self.embedding_dim = decoder_config['hidden_size']
+    self.class_conditional = class_conditional
+    self.num_classes = num_classes
+    self.max_seq_length = max_seq_length
+    self.decoder_cross_attention = decoder_cross_attention
+    self.use_discrete_vertex_embeddings = use_discrete_vertex_embeddings
+    self.quantization_bits = quantization_bits
+
+    with self._enter_variable_scope():
+      self.decoder = TransformerDecoder(**decoder_config)
+      self.encoder = TransformerEncoder(**encoder_config)
+
+  @snt.reuse_variables
+  def _embed_class_label(self, labels):
+    """Embeds class label with learned embedding matrix."""
+    init_dict = {'embeddings': tf.glorot_uniform_initializer}
+    return snt.Embed(
+        vocab_size=self.num_classes,
+        embed_dim=self.embedding_dim,
+        initializers=init_dict,
+        densify_gradients=True,
+        name='class_label')(labels)
+
+  @snt.reuse_variables
+  def _prepare_context(self, context, is_training=False):
+    """Prepare class label and vertex context."""
+    if self.class_conditional:
+      global_context_embedding = self._embed_class_label(context['class_label'])
+    else:
+      global_context_embedding = None
+    vertex_embeddings = self._embed_vertices(
+        context['vertices'], context['vertices_mask'],
+        is_training=is_training)
+    if self.decoder_cross_attention:
+      sequential_context_embeddings = (
+          vertex_embeddings *
+          tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
+                 constant_values=1)[Ellipsis, None])
+    else:
+      sequential_context_embeddings = None
+    return (vertex_embeddings, global_context_embedding,
+            sequential_context_embeddings)
+
+  @snt.reuse_variables
+  def _embed_vertices(self, vertices, vertices_mask, is_training=False):
+    """Embeds vertices with transformer encoder."""
+    if self.use_discrete_vertex_embeddings:
+      vertex_embeddings = 0.
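+      # Sum a separately learned embedding for each of the three quantized
+      # vertex coordinates.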
+ verts_quantized = quantize_verts(vertices, self.quantization_bits) + for c in range(3): + vertex_embeddings += snt.Embed( + vocab_size=256, + embed_dim=self.embedding_dim, + initializers={'embeddings': tf.glorot_uniform_initializer}, + densify_gradients=True, + name='coord_{}'.format(c))(verts_quantized[Ellipsis, c]) + else: + vertex_embeddings = tf.layers.dense( + vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings') + vertex_embeddings *= vertices_mask[Ellipsis, None] + + # Pad vertex embeddings with learned embeddings for stopping and new face + # tokens + stopping_embeddings = tf.get_variable( + 'stopping_embeddings', shape=[1, 2, self.embedding_dim]) + stopping_embeddings = tf.tile(stopping_embeddings, + [tf.shape(vertices)[0], 1, 1]) + vertex_embeddings = tf.concat( + [stopping_embeddings, vertex_embeddings], axis=1) + + # Pass through Transformer encoder + vertex_embeddings = self.encoder(vertex_embeddings, is_training=is_training) + return vertex_embeddings + + @snt.reuse_variables + def _embed_inputs(self, faces_long, vertex_embeddings, + global_context_embedding=None): + """Embeds face sequences and adds within and between face positions.""" + + # Face value embeddings are gathered vertex embeddings + face_embeddings = tf.gather(vertex_embeddings, faces_long, batch_dims=1) + + # Position embeddings + pos_embeddings = snt.Embed( + vocab_size=self.max_seq_length, + embed_dim=self.embedding_dim, + initializers={'embeddings': tf.glorot_uniform_initializer}, + densify_gradients=True, + name='coord_embeddings')(tf.range(tf.shape(faces_long)[1])) + + # Step zero embeddings + batch_size = tf.shape(face_embeddings)[0] + if global_context_embedding is None: + zero_embed = tf.get_variable( + 'embed_zero', shape=[1, 1, self.embedding_dim]) + zero_embed_tiled = tf.tile(zero_embed, [batch_size, 1, 1]) + else: + zero_embed_tiled = global_context_embedding[:, None] + + # Aggregate embeddings + embeddings = face_embeddings + pos_embeddings[None] + embeddings = tf.concat([zero_embed_tiled, embeddings], axis=1) + + return embeddings + + @snt.reuse_variables + def _project_to_pointers(self, inputs): + """Projects transformer outputs to pointer vectors.""" + return tf.layers.dense( + inputs, + self.embedding_dim, + use_bias=True, + kernel_initializer=tf.zeros_initializer(), + name='project_to_pointers' + ) + + @snt.reuse_variables + def _create_dist(self, + vertex_embeddings, + vertices_mask, + faces_long, + global_context_embedding=None, + sequential_context_embeddings=None, + temperature=1., + top_k=0, + top_p=1., + is_training=False, + cache=None): + """Outputs categorical dist for vertex indices.""" + + # Embed inputs + decoder_inputs = self._embed_inputs( + faces_long, vertex_embeddings, global_context_embedding) + + # Pass through Transformer decoder + if cache is not None: + decoder_inputs = decoder_inputs[:, -1:] + decoder_outputs = self.decoder( + decoder_inputs, + cache=cache, + sequential_context_embeddings=sequential_context_embeddings, + is_training=is_training) + + # Get pointers + pred_pointers = self._project_to_pointers(decoder_outputs) + + # Get logits and mask + logits = tf.matmul(pred_pointers, vertex_embeddings, transpose_b=True) + logits /= tf.sqrt(float(self.embedding_dim)) + f_verts_mask = tf.pad( + vertices_mask, [[0, 0], [2, 0]], constant_values=1.)[:, None] + logits *= f_verts_mask + logits -= (1. 
- f_verts_mask) * 1e9
+    logits /= temperature
+    logits = top_k_logits(logits, top_k)
+    logits = top_p_logits(logits, top_p)
+    return tfd.Categorical(logits=logits)
+
+  def _build(self, batch, is_training=False):
+    """Pass batch through face model and get log probabilities.
+
+    Args:
+      batch: Dictionary containing:
+        'vertices': Dequantized vertex tensor of shape
+          [batch_size, num_vertices, 3].
+        'faces': int32 tensor of shape [batch_size, seq_length] with flattened
+          faces.
+        'vertices_mask': float32 tensor with shape
+          [batch_size, num_vertices] that masks padded elements in 'vertices'.
+      is_training: If True, use dropout.
+
+    Returns:
+      pred_dist: tfd.Categorical predictive distribution with batch shape
+          [batch_size, seq_length].
+    """
+    vertex_embeddings, global_context, seq_context = self._prepare_context(
+        batch, is_training=is_training)
+    pred_dist = self._create_dist(
+        vertex_embeddings,
+        batch['vertices_mask'],
+        batch['faces'][:, :-1],
+        global_context_embedding=global_context,
+        sequential_context_embeddings=seq_context,
+        is_training=is_training)
+    return pred_dist
+
+  def sample(self,
+             context,
+             max_sample_length=None,
+             temperature=1.,
+             top_k=0,
+             top_p=1.,
+             only_return_complete=True):
+    """Sample from face model using caching.
+
+    Args:
+      context: Dictionary of context, including 'vertices' and
+        'vertices_mask'. See _prepare_context for details.
+      max_sample_length: Maximum length of sampled face sequences. Sequences
+        that do not complete are truncated.
+      temperature: Scalar softmax temperature > 0.
+      top_k: Number of tokens to keep for top-k sampling.
+      top_p: Proportion of probability mass to keep for top-p sampling.
+      only_return_complete: If True, only return completed samples. Otherwise
+        return all samples along with completed indicator.
+
+    Returns:
+      outputs: Output dictionary with fields:
+        'context': Input context dictionary, filtered to completed samples
+          when only_return_complete is True.
+        'completed': Boolean tensor of shape [num_samples]. If True then
+          corresponding sample completed within max_sample_length.
+        'faces': Tensor of flattened face-index samples with shape
+          [num_samples, max_sample_length].
+        'num_face_indices': Tensor indicating the number of face indices for
+          each example in the padded face samples.
+ """ + vertex_embeddings, global_context, seq_context = self._prepare_context( + context, is_training=False) + num_samples = tf.shape(vertex_embeddings)[0] + + def _loop_body(i, samples, cache): + """While-loop body for autoregression calculation.""" + pred_dist = self._create_dist( + vertex_embeddings, + context['vertices_mask'], + samples, + global_context_embedding=global_context, + sequential_context_embeddings=seq_context, + cache=cache, + temperature=temperature, + top_k=top_k, + top_p=top_p) + next_sample = pred_dist.sample()[:, -1:] + samples = tf.concat([samples, next_sample], axis=1) + return i + 1, samples, cache + + def _stopping_cond(i, samples, cache): + """Stopping conditions for autoregressive calculation.""" + del i, cache # Unused + return tf.reduce_any(tf.reduce_all(tf.not_equal(samples, 0), axis=-1)) + + # While loop sampling with caching + samples = tf.zeros([num_samples, 0], dtype=tf.int32) + max_sample_length = max_sample_length or self.max_seq_length + cache, cache_shape_invariants = self.decoder.create_init_cache(num_samples) + _, f, _ = tf.while_loop( + cond=_stopping_cond, + body=_loop_body, + loop_vars=(0, samples, cache), + shape_invariants=(tf.TensorShape([]), tf.TensorShape([None, None]), + cache_shape_invariants), + back_prop=False, + parallel_iterations=1, + maximum_iterations=max_sample_length) + + # Record completed samples + complete_samples = tf.reduce_any(tf.equal(f, 0), axis=-1) + + # Find number of faces + sample_length = tf.shape(f)[-1] + # Get largest new face (1) index as stopping point for incomplete samples. + max_one_ind = tf.reduce_max( + tf.range(sample_length)[None] * tf.cast(tf.equal(f, 1), tf.int32), + axis=-1) + zero_inds = tf.cast( + tf.argmax(tf.cast(tf.equal(f, 0), tf.int32), axis=-1), tf.int32) + num_face_indices = tf.where(complete_samples, zero_inds, max_one_ind) + 1 + + # Mask faces beyond stopping token with zeros + # This mask has a -1 in order to replace the last new face token with zero + faces_mask = tf.cast( + tf.range(sample_length)[None] < num_face_indices[:, None] - 1, tf.int32) + f *= faces_mask + # This is the real mask + faces_mask = tf.cast( + tf.range(sample_length)[None] < num_face_indices[:, None], tf.int32) + + # Pad to maximum size with zeros + pad_size = max_sample_length - sample_length + f = tf.pad(f, [[0, 0], [0, pad_size]]) + + if only_return_complete: + f = tf.boolean_mask(f, complete_samples) + num_face_indices = tf.boolean_mask(num_face_indices, complete_samples) + context = tf.nest.map_structure( + lambda x: tf.boolean_mask(x, complete_samples), context) + complete_samples = tf.boolean_mask(complete_samples, complete_samples) + + # outputs + outputs = { + 'context': context, + 'completed': complete_samples, + 'faces': f, + 'num_face_indices': num_face_indices, + } + return outputs diff --git a/polygen/run.sh b/polygen/run.sh new file mode 100644 index 0000000..3276ec2 --- /dev/null +++ b/polygen/run.sh @@ -0,0 +1,20 @@ +#!/bin/sh +# Copyright 2020 Deepmind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +python3 -m venv polygen +source polygen/bin/activate +pip3 install . +python3 model_test.py +deactivate diff --git a/polygen/sample-pretrained.ipynb b/polygen/sample-pretrained.ipynb new file mode 100644 index 0000000..353febd --- /dev/null +++ b/polygen/sample-pretrained.ipynb @@ -0,0 +1,294 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "lUh_eWpealmh" + }, + "source": [ + "Copyright 2020 DeepMind Technologies Limited\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "kYd9gIfGJYZ8" + }, + "source": [ + "## Clone repo and import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ux33ZDQ_tqUV" + }, + "outputs": [], + "source": [ + "!pip install tensorflow==1.15\n", + "!pip install dm-sonnet==1.36\n", + "!pip install tensor2tensor==1.14\n", + "\n", + "import time\n", + "import numpy as np\n", + "import tensorflow.compat.v1 as tf\n", + "tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n", + "import matplotlib.pyplot as plt\n", + "\n", + "!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n", + "%cd deepmind_research/polygen\n", + "import modules\n", + "import data_utils" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YUZNqHTVJbm3" + }, + "source": [ + "## Download pre-trained model weights from Google Cloud Storage" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "LpZjBUmq10gX" + }, + "outputs": [], + "source": [ + "!mkdir /tmp/vertex_model\n", + "!mkdir /tmp/face_model\n", + "!gsutil cp gs://deepmind-research-polygen/vertex_model.tar.gz /tmp/vertex_model/\n", + "!gsutil cp gs://deepmind-research-polygen/face_model.tar.gz /tmp/face_model/\n", + "!tar xvfz /tmp/vertex_model/vertex_model.tar.gz -C /tmp/vertex_model/\n", + "!tar xvfz /tmp/face_model/face_model.tar.gz -C /tmp/face_model/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vXhMOoxaJ3Xb" + }, + "source": [ + "## Pre-trained model config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7VGSSS9vJSn-" + }, + "outputs": [], + "source": [ + "vertex_module_config=dict(\n", + " decoder_config=dict(\n", + " hidden_size=512,\n", + " fc_size=2048,\n", + " num_heads=8,\n", + " layer_norm=True,\n", + " num_layers=24,\n", + " dropout_rate=0.4,\n", + " re_zero=True,\n", + " memory_efficient=True\n", + " ),\n", + " quantization_bits=8,\n", + " class_conditional=True,\n", + " max_num_input_verts=5000,\n", + " use_discrete_embeddings=True,\n", + " )\n", + "\n", + 
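+    "# The pre-trained face model uses a 10-layer vertex encoder and a 14-layer\n",
+    "# face decoder, operating on longer (8000 token) face sequences.\n",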
"face_module_config=dict(\n", + " encoder_config=dict(\n", + " hidden_size=512,\n", + " fc_size=2048,\n", + " num_heads=8,\n", + " layer_norm=True,\n", + " num_layers=10,\n", + " dropout_rate=0.2,\n", + " re_zero=True,\n", + " memory_efficient=True,\n", + " ),\n", + " decoder_config=dict(\n", + " hidden_size=512,\n", + " fc_size=2048,\n", + " num_heads=8,\n", + " layer_norm=True,\n", + " num_layers=14,\n", + " dropout_rate=0.2,\n", + " re_zero=True,\n", + " memory_efficient=True,\n", + " ),\n", + " class_conditional=False,\n", + " decoder_cross_attention=True,\n", + " use_discrete_vertex_embeddings=True,\n", + " max_seq_length=8000,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "WNXf_XbuKW4S" + }, + "source": [ + "## Generate class-conditional samples\n", + "\n", + "Try varying the `class_id` parameter to generate meshes from different object categories. Good classes to try are tables (49), lamps (30), and cabinets (32). \n", + "\n", + "We can also specify the maximum number of vertices / face indices we want to see in the generated meshes using `max_num_vertices` and `max_num_face_indices`. The code will keep generating batches of samples until there are at least `num_samples_min` complete samples with the required number of vertices / faces.\n", + "\n", + "`top_p_vertex_model` and `top_p_face_model` control how varied the outputs are, with `1.` being the most varied, and `0.` the least varied. `0.9` is a good value for both the vertex and face models.\n", + "\n", + "Sampling should take around 2-5 minutes with a colab GPU using the default settings depending on the object class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "kqKMbPJJu3lk" + }, + "outputs": [], + "source": [ + "class_id = '49) table' #@param ['0) airplane,aeroplane,plane','1) ashcan,trash can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin','2) bag,traveling bag,travelling bag,grip,suitcase','3) basket,handbasket','4) bathtub,bathing tub,bath,tub','5) bed','6) bench','7) birdhouse','8) bookshelf','9) bottle','10) bowl','11) bus,autobus,coach,charabanc,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehi','12) cabinet','13) camera,photographic camera','14) can,tin,tin can','15) cap','16) car,auto,automobile,machine,motorcar','17) cellular telephone,cellular phone,cellphone,cell,mobile phone','18) chair','19) clock','20) computer keyboard,keypad','21) dishwasher,dish washer,dishwashing machine','22) display,video display','23) earphone,earpiece,headphone,phone','24) faucet,spigot','25) file,file cabinet,filing cabinet','26) guitar','27) helmet','28) jar','29) knife','30) lamp','31) laptop,laptop computer','32) loudspeaker,speaker,speaker unit,loudspeaker system,speaker system','33) mailbox,letter box','34) microphone,mike','35) microwave,microwave oven','36) motorcycle,bike','37) mug','38) piano,pianoforte,forte-piano','39) pillow','40) pistol,handgun,side arm,shooting iron','41) pot,flowerpot','42) printer,printing machine','43) remote control,remote','44) rifle','45) rocket,projectile','46) skateboard','47) sofa,couch,lounge','48) stove','49) table','50) telephone,phone,telephone set','51) tower','52) train,railroad train','53) vessel,watercraft','54) washer,automatic washer,washing machine']\n", + "num_samples_min = 1 #@param\n", + "num_samples_batch = 8 #@param\n", + "max_num_vertices = 400 #@param\n", + "max_num_face_indices = 2000 #@param\n", + 
"top_p_vertex_model = 0.9 #@param\n", + "top_p_face_model = 0.9 #@param\n", + "\n", + "tf.reset_default_graph()\n", + "\n", + "# Build models\n", + "vertex_model = modules.VertexModel(**vertex_module_config)\n", + "face_model = modules.FaceModel(**face_module_config)\n", + "\n", + "# Tile out class label to every element in batch\n", + "class_id = int(class_id.split(')')[0])\n", + "vertex_model_context = {'class_label': tf.fill([num_samples_batch,], class_id)}\n", + "vertex_samples = vertex_model.sample(\n", + " num_samples_batch, context=vertex_model_context, \n", + " max_sample_length=max_num_vertices, top_p=top_p_vertex_model, \n", + " recenter_verts=True, only_return_complete=True)\n", + "vertex_model_saver = tf.train.Saver(var_list=vertex_model.variables)\n", + "\n", + "# The face model generates samples conditioned on a context, which here is\n", + "# the vertex model samples\n", + "face_samples = face_model.sample(\n", + " vertex_samples, max_sample_length=max_num_face_indices, \n", + " top_p=top_p_face_model, only_return_complete=True)\n", + "face_model_saver = tf.train.Saver(var_list=face_model.variables)\n", + "\n", + "# Start sampling\n", + "start = time.time()\n", + "print('Generating samples...')\n", + "with tf.Session() as sess:\n", + " vertex_model_saver.restore(sess, '/tmp/vertex_model/model')\n", + " face_model_saver.restore(sess, '/tmp/face_model/model')\n", + " mesh_list = []\n", + " num_samples_complete = 0\n", + " while num_samples_complete \u003c num_samples_min:\n", + " v_samples_np = sess.run(vertex_samples)\n", + " if v_samples_np['completed'].size == 0:\n", + " print('No vertex samples completed in this batch. Try increasing ' +\n", + " 'max_num_vertices.')\n", + " continue\n", + " f_samples_np = sess.run(\n", + " face_samples,\n", + " {vertex_samples[k]: v_samples_np[k] for k in vertex_samples.keys()})\n", + " v_samples_np = f_samples_np['context']\n", + " num_samples_complete_batch = f_samples_np['completed'].sum()\n", + " num_samples_complete += num_samples_complete_batch\n", + " print('Num. samples complete: {}'.format(num_samples_complete))\n", + " for k in range(num_samples_complete_batch):\n", + " verts = v_samples_np['vertices'][k][:v_samples_np['num_vertices'][k]]\n", + " faces = data_utils.unflatten_faces(\n", + " f_samples_np['faces'][k][:f_samples_np['num_face_indices'][k]])\n", + " mesh_list.append({'vertices': verts, 'faces': faces})\n", + "end = time.time()\n", + "print('sampling time: {}'.format(end - start))\n", + "\n", + "data_utils.plot_meshes(mesh_list, ax_lims=0.4) " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "OOQV6pMvSymz" + }, + "source": [ + "## Export meshes as `.obj` files\n", + "Pick a `mesh_id` (starting at 0) corresponding to the samples generated above. Refresh the colab file browser to find an `.obj` file with the mesh data." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "fO0Klbq2Sx0m"
+   },
+   "outputs": [],
+   "source": [
+    "mesh_id = 4 #@param\n",
+    "data_utils.write_obj(\n",
+    "    mesh_list[mesh_id]['vertices'], mesh_list[mesh_id]['faces'], \n",
+    "    'mesh-{}.obj'.format(mesh_id))"
+   ]
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "collapsed_sections": [],
+   "name": "sample-pretrained.ipynb",
+   "provenance": [
+    {
+     "file_id": "1yj-oHYqCnwYVGSM22coa68NqnSoNdyVr",
+     "timestamp": 1591622616999
+    },
+    {
+     "file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
+     "timestamp": 1591264007511
+    }
+   ]
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/polygen/setup.py b/polygen/setup.py
new file mode 100644
index 0000000..80191e8
--- /dev/null
+++ b/polygen/setup.py
@@ -0,0 +1,40 @@
+# Copyright 2020 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Setup for pip package."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from setuptools import find_packages
+from setuptools import setup
+
+
+# Pinned to match the versions installed in the training and sampling colabs.
+REQUIRED_PACKAGES = ['numpy', 'dm-sonnet==1.36', 'tensorflow==1.15',
+                     'tensor2tensor==1.14', 'networkx', 'matplotlib', 'six']
+
+setup(
+    name='polygen',
+    version='0.1',
+    description='A library for PolyGen: An Autoregressive Generative Model of 3D Meshes.',
+    url='https://github.com/deepmind/deepmind-research/polygen',
+    author='DeepMind',
+    author_email='no-reply@google.com',
+    # Contained modules and scripts.
+    packages=find_packages(),
+    install_requires=REQUIRED_PACKAGES,
+    platforms=['any'],
+    license='Apache 2.0',
+)
diff --git a/polygen/training.ipynb b/polygen/training.ipynb
new file mode 100644
index 0000000..b80e4ed
--- /dev/null
+++ b/polygen/training.ipynb
@@ -0,0 +1,328 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "szgmaK1HajOc"
+   },
+   "source": [
+    "Copyright 2020 DeepMind Technologies Limited\n",
+    "\n",
+    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+    "you may not use this file except in compliance with the License.\n",
+    "You may obtain a copy of the License at\n",
+    "\n",
+    "    https://www.apache.org/licenses/LICENSE-2.0\n",
+    "\n",
+    "Unless required by applicable law or agreed to in writing, software\n",
+    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+    "See the License for the specific language governing permissions and\n",
+    "limitations under the License."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "_dv0afOrKheU" + }, + "source": [ + "## Clone repo and import dependencies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ux33ZDQ_tqUV" + }, + "outputs": [], + "source": [ + "!pip install tensorflow==1.15\n", + "!pip install dm-sonnet==1.36\n", + "!pip install tensor2tensor==1.14\n", + "\n", + "import os\n", + "import numpy as np\n", + "import tensorflow.compat.v1 as tf\n", + "tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n", + "import matplotlib.pyplot as plt\n", + "\n", + "!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n", + "%cd deepmind_research/polygen\n", + "import modules\n", + "import data_utils" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "U3GDZhJ5wGOf" + }, + "source": [ + "## Prepare a synthetic dataset\n", + "We prepare a dataset of meshes using four simple geometric primitives.\n", + "\n", + "The important function here is `data_utils.load_process_mesh`, which loads the raw `.obj` file, normalizes and centers the meshes, and applies quantization to the vertex positions. The mesh faces are flattened and treated as a long sequence, with a new-face token (`=1`) separating the faces. For each of the four synthetic meshes, we associate a unique class label, so we can train class-conditional models.\n", + "\n", + "After processing the raw mesh data into numpy arrays, we create a `tf.data.Dataset` that we can use to feed data to our models. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "3QAqwyZjtOdC" + }, + "outputs": [], + "source": [ + "# Prepare synthetic dataset\n", + "ex_list = []\n", + "for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):\n", + " mesh_dict = data_utils.load_process_mesh(\n", + " os.path.join('meshes', '{}.obj'.format(mesh)))\n", + " mesh_dict['class_label'] = k\n", + " ex_list.append(mesh_dict)\n", + "synthetic_dataset = tf.data.Dataset.from_generator(\n", + " lambda: ex_list, \n", + " output_types={\n", + " 'vertices': tf.int32, 'faces': tf.int32, 'class_label': tf.int32},\n", + " output_shapes={\n", + " 'vertices': tf.TensorShape([None, 3]), 'faces': tf.TensorShape([None]), \n", + " 'class_label': tf.TensorShape(())}\n", + " )\n", + "ex = synthetic_dataset.make_one_shot_iterator().get_next()\n", + "\n", + "# Inspect the first mesh\n", + "with tf.Session() as sess:\n", + " ex_np = sess.run(ex)\n", + "print(ex_np)\n", + "\n", + "# Plot the meshes\n", + "mesh_list = []\n", + "with tf.Session() as sess:\n", + " for i in range(4):\n", + " ex_np = sess.run(ex)\n", + " mesh_list.append(\n", + " {'vertices': data_utils.dequantize_verts(ex_np['vertices']), \n", + " 'faces': data_utils.unflatten_faces(ex_np['faces'])})\n", + "data_utils.plot_meshes(mesh_list, ax_lims=0.4)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "9G2FCQQyyTXw" + }, + "source": [ + "## Vertex model\n", + "\n", + "#### Prepare the dataset for vertex model training\n", + "We need to perform some additional processing to make the dataset ready for vertex model training. In particular, `data_utils.make_vertex_model_dataset` flattens the `[V, 3]` vertex arrays, ordering by `Z-\u003eY-\u003eX` coordinates. It also creates masks, which are used to mask padded elements in data batches. 
Random shift augmentation can also be applied (via `apply_random_shift`) to make the modelling task more challenging.\n",
+    "\n",
+    "#### Create a vertex model\n",
+    "`modules.VertexModel` is a Sonnet module that defines the vertex model. Calling the module on a batch of data produces a predictive distribution for each vertex coordinate in the sequence. The basis of the vertex model is a Transformer decoder, and we specify its parameters in `decoder_config`. \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "colab": {},
+    "colab_type": "code",
+    "id": "o2KCoDeeFP8C"
+   },
+   "outputs": [],
+   "source": [
+    "# Prepare the dataset for vertex model training\n",
+    "vertex_model_dataset = data_utils.make_vertex_model_dataset(\n",
+    "    synthetic_dataset, apply_random_shift=False)\n",
+    "vertex_model_dataset = vertex_model_dataset.repeat()\n",
+    "vertex_model_dataset = vertex_model_dataset.padded_batch(\n",
+    "    4, padded_shapes=vertex_model_dataset.output_shapes)\n",
+    "vertex_model_dataset = vertex_model_dataset.prefetch(1)\n",
+    "vertex_model_batch = vertex_model_dataset.make_one_shot_iterator().get_next()\n",
+    "\n",
+    "# Create vertex model\n",
+    "vertex_model = modules.VertexModel(\n",
+    "    decoder_config={\n",
+    "        'hidden_size': 128,\n",
+    "        'fc_size': 512, \n",
+    "        'num_layers': 3,\n",
+    "        'dropout_rate': 0.\n",
+    "    },\n",
+    "    class_conditional=True,\n",
+    "    num_classes=4,\n",
+    "    max_num_input_verts=250,\n",
+    "    quantization_bits=8,\n",
+    ")\n",
+    "vertex_model_pred_dist = vertex_model(vertex_model_batch)\n",
+    "vertex_model_loss = -tf.reduce_sum(\n",
+    "    vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat']) * \n",
+    "    vertex_model_batch['vertices_flat_mask'])\n",
+    "vertex_samples = vertex_model.sample(\n",
+    "    4, context=vertex_model_batch, max_sample_length=200, top_p=0.95,\n",
+    "    recenter_verts=False, only_return_complete=False)\n",
+    "\n",
+    "print(vertex_model_batch)\n",
+    "print(vertex_model_pred_dist)\n",
+    "print(vertex_samples)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text",
+    "id": "-9RNYr5x1jov"
+   },
+   "source": [
+    "## Face model\n",
+    "\n",
+    "#### Prepare the dataset for face model training\n",
+    "We need to perform some additional processing to make the dataset ready for face model training. In particular, `data_utils.make_face_model_dataset` pads the input vertex arrays and flattened face sequences, and creates the masks that identify padded elements in data batches. As with the vertex data, random shift augmentation can optionally be applied.\n",
+    "\n",
+    "#### Create a face model\n",
+    "`modules.FaceModel` is a Sonnet module that defines the face model. Calling the module on a batch of data produces a pointer distribution over vertex indices for each position in the face sequence. The face model encodes input vertices with a Transformer encoder and decodes face sequences with a Transformer decoder, and we specify their parameters in `encoder_config` and `decoder_config`. "
+   ]
+  },
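+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "The face model outputs a *pointer* that is compared with the embedding of each input vertex (see `modules.FaceModel._create_dist`), so sampled face indices always refer to vertices present in the conditioning input.\n"
+   ]
+  },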
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "a2yO6dOGzn8c" + }, + "outputs": [], + "source": [ + "face_model_dataset = data_utils.make_face_model_dataset(\n", + " synthetic_dataset, apply_random_shift=False)\n", + "face_model_dataset = face_model_dataset.repeat()\n", + "face_model_dataset = face_model_dataset.padded_batch(\n", + " 4, padded_shapes=face_model_dataset.output_shapes)\n", + "face_model_dataset = face_model_dataset.prefetch(1)\n", + "face_model_batch = face_model_dataset.make_one_shot_iterator().get_next()\n", + "\n", + "# Create face model\n", + "face_model = modules.FaceModel(\n", + " encoder_config={\n", + " 'hidden_size': 128,\n", + " 'fc_size': 512, \n", + " 'num_layers': 3,\n", + " 'dropout_rate': 0.\n", + " },\n", + " decoder_config={\n", + " 'hidden_size': 128,\n", + " 'fc_size': 512, \n", + " 'num_layers': 3,\n", + " 'dropout_rate': 0.\n", + " },\n", + " class_conditional=False,\n", + " max_seq_length=500,\n", + " quantization_bits=8,\n", + " decoder_cross_attention=True,\n", + " use_discrete_vertex_embeddings=True,\n", + ")\n", + "face_model_pred_dist = face_model(face_model_batch)\n", + "face_model_loss = -tf.reduce_sum(\n", + " face_model_pred_dist.log_prob(face_model_batch['faces']) * \n", + " face_model_batch['faces_mask'])\n", + "face_samples = face_model.sample(\n", + " context=vertex_samples, max_sample_length=500, top_p=0.95,\n", + " only_return_complete=False)\n", + "print(face_model_batch)\n", + "print(face_model_pred_dist)\n", + "print(face_samples)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "hL7yloXB1pUb" + }, + "source": [ + "## Train on the synthetic data\n", + "\n", + "Now that we've created vertex and face models and their respective data loaders, we can train them and look at some outputs. While we train the models together here, they can be trained seperately and recombined later if required. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "hjrbofa8zqQt" + }, + "outputs": [], + "source": [ + "# Optimization settings\n", + "learning_rate = 5e-4\n", + "training_steps = 500\n", + "check_step = 5\n", + "\n", + "# Create an optimizer an minimize the summed log probability of the mesh \n", + "# sequences\n", + "optimizer = tf.train.AdamOptimizer(learning_rate)\n", + "vertex_model_optim_op = optimizer.minimize(vertex_model_loss)\n", + "face_model_optim_op = optimizer.minimize(face_model_loss)\n", + "\n", + "# Training loop\n", + "with tf.Session() as sess:\n", + " sess.run(tf.global_variables_initializer())\n", + " for n in range(training_steps):\n", + " if n % check_step == 0:\n", + " v_loss, f_loss = sess.run((vertex_model_loss, face_model_loss))\n", + " print('Step {}'.format(n))\n", + " print('Loss (vertices) {}'.format(v_loss))\n", + " print('Loss (faces) {}'.format(f_loss))\n", + " v_samples_np, f_samples_np, b_np = sess.run(\n", + " (vertex_samples, face_samples, vertex_model_batch))\n", + " mesh_list = []\n", + " for n in range(4):\n", + " mesh_list.append(\n", + " {\n", + " 'vertices': v_samples_np['vertices'][n][:v_samples_np['num_vertices'][n]],\n", + " 'faces': data_utils.unflatten_faces(\n", + " f_samples_np['faces'][n][:f_samples_np['num_face_indices'][n]])\n", + " }\n", + " )\n", + " data_utils.plot_meshes(mesh_list, ax_lims=0.5)\n", + " sess.run((vertex_model_optim_op, face_model_optim_op))" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "training.ipynb", + "provenance": [ + { + "file_id": "1QL8ib2FKPGUWFQbuX8AttUk-H34Al8Ue", + "timestamp": 1591364245034 + }, + { + "file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y", + "timestamp": 1591355096822 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}