Adds PolyGen to public Deepmind Research Github repository

PiperOrigin-RevId: 327802622
Charlie Nash
2020-08-21 14:53:44 +01:00
committed by Saran Tunyasuvunakool
parent 56031adaa4
commit 22c3daff19
14 changed files with 3123 additions and 0 deletions
+2
@@ -50,6 +50,8 @@ https://deepmind.com/research/publications/
* [Graph Matching Networks for Learning the Similarity of Graph Structured
Objects](graph_matching_networks), ICML 2019
* [REGAL: Transfer Learning for Fast Optimization of Computation Graphs](regal)
* [PolyGen: An Autoregressive Generative Model of 3D Meshes](polygen), ICML 2020
## Disclaimer
+65
@@ -0,0 +1,65 @@
# PolyGen: An Autoregressive Generative Model of 3D Meshes
![](media/example_samples.png)
This package provides an implementation of PolyGen as described in:
> **PolyGen: An Autoregressive Generative Model of 3D Meshes**, *Charlie Nash, Yaroslav Ganin, S. M. Ali Eslami, Peter W. Battaglia*, ICML, 2020. ([abs](https://arxiv.org/abs/2002.10880))
PolyGen is a generative model of 3D meshes that sequentially outputs mesh
vertices and faces. PolyGen consists of two parts: a vertex model, which
unconditionally models mesh vertices, and a face model, which models the mesh
faces conditioned on input vertices. The vertex model uses a masked Transformer
decoder to express a distribution over vertex sequences. For the face model
we combine Transformers with [pointer networks](https://arxiv.org/abs/1506.03134)
to express a distribution over variable-length vertex sequences.
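As a rough illustration of the pointer mechanism mentioned above, the sketch below scores each candidate vertex by a dot product between a decoder output vector and per-vertex embeddings, then normalizes over the valid (unpadded) vertices only. It is a NumPy toy under assumed shapes, not the code in `modules.py`; the function name `pointer_distribution` and all array names are hypothetical.

```python
import numpy as np


def pointer_distribution(query, vertex_embeddings, vertex_mask):
  """Toy pointer step: softmax over dot-product scores of valid vertices.

  Args:
    query: [d] decoder output for the current step (hypothetical name).
    vertex_embeddings: [n, d] array with one embedding per candidate vertex.
    vertex_mask: [n] array, 1.0 for real vertices and 0.0 for padding.
      At least one entry is assumed to be valid.

  Returns:
    [n] probabilities that sum to one and are zero on padded vertices.
  """
  scores = vertex_embeddings @ query
  # Padded vertices get -inf so that they receive zero probability.
  scores = np.where(vertex_mask > 0, scores, -np.inf)
  # Subtract the maximum for numerical stability before exponentiating.
  weights = np.exp(scores - scores.max())
  return weights / weights.sum()


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 8))
mask = np.array([1., 1., 1., 0., 0.])
probs = pointer_distribution(rng.normal(size=8), embeddings, mask)
```

Because the softmax runs over however many vertices are unmasked, the same weights can point into vertex sets of different sizes, which is what makes the output vocabulary variable-length.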
In this repository we provide model code in `modules.py`, as well as data
processing utilities in `data_utils.py`. We also provide Colabs that demo
training PolyGen from scratch on a toy dataset, as well as sampling from a
pre-trained model.
There are some minor differences between this implementation and the paper:
* We add global information (e.g. class label embeddings) as an additional input in the first sequence position rather than projecting it at each layer. This reduces the number of parameters, but does not significantly impact performance.
* We use [ReZero](https://arxiv.org/abs/2003.04887), which improves training speed.
* We train with only shifting augmentations, which we find to be as effective as the combination of augmentations described in the paper. This helps to simplify the data pre-processing pipeline.
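For readers unfamiliar with ReZero, the following is a minimal sketch of the residual form it proposes: each sub-layer output is scaled by a learned scalar that starts at zero, so every block is the identity at initialization. How `modules.py` wires this into its Transformer layers is not shown here; `rezero_residual` and the toy `sublayer` are hypothetical names used only for illustration.

```python
import numpy as np


def rezero_residual(x, sublayer, alpha):
  """Returns x + alpha * sublayer(x), the ReZero residual connection."""
  return x + alpha * sublayer(x)


def sublayer(h):
  """Stand-in for an attention or feed-forward block (illustrative only)."""
  return np.tanh(h @ np.full((4, 4), 0.5))


x = np.ones((2, 4))
# alpha is a trainable scalar initialized to zero, so the block starts out as
# the identity and the sub-layer contribution grows during training.
out_at_init = rezero_residual(x, sublayer, alpha=0.0)
out_later = rezero_residual(x, sublayer, alpha=0.1)
```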
## Training Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/training.ipynb)
To train a PolyGen model from scratch on a collection of simple meshes, use this
Colab. It demonstrates the data pre-processing required to create inputs for
the vertex and face models.
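The flattening steps can be sketched in plain NumPy. The snippet mirrors what `make_vertex_model_dataset` and `flatten_faces` in `data_utils.py` do, but the tiny mesh below is made up for illustration and the code does not import the package.

```python
import numpy as np

# Hypothetical quantized mesh: four 8-bit vertices and two triangles.
vertices = np.array(
    [[0, 0, 0], [255, 0, 0], [0, 255, 0], [255, 255, 0]], dtype=np.int32)
faces = [[0, 1, 2], [0, 2, 3]]

# Vertex model input: reverse the coordinate order to (z, y, x), flatten,
# shift every value by one, and append a zero stopping token.
vertices_zyx = vertices[:, ::-1]
vertices_flat = np.pad(vertices_zyx.reshape(-1) + 1, (0, 1))

# Face model input: end every face but the last with -1 ("new face") and the
# last with -2 ("stop"), then shift by two so that stop = 0, new face = 1,
# and vertex index i becomes i + 2.
tokens = []
for face in faces[:-1]:
  tokens += face + [-1]
tokens += faces[-1] + [-2]
faces_flat = np.array(tokens) + 2
```

Reserving the low token values for the stop and new-face markers is why the vertex values are offset by one and the face indices by two.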
## Sampling pre-trained model Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/sample-pretrained.ipynb)
To sample from a model pre-trained on [ShapeNet](https://www.shapenet.org/),
use this Colab. The model is class-conditional, and is trained on longer
sequence lengths than those described in the paper. The Colab loads its
checkpoints from this [Google Cloud Storage
bucket](https://console.cloud.google.com/storage/browser/deepmind-research-polygen).
## Installation
To install the package locally run:
```bash
git clone https://github.com/deepmind/deepmind-research.git
cd deepmind-research/polygen
pip install -e .
```
## Giving Credit
If you use this code in your work, we ask you to cite this paper:
```
@article{nash2020polygen,
author={Charlie Nash and Yaroslav Ganin and S. M. Ali Eslami and Peter W. Battaglia},
title={PolyGen: An Autoregressive Generative Model of 3D Meshes},
journal={ICML},
year={2020}
}
```
## Disclaimer
This is not an official Google product.
+409
@@ -0,0 +1,409 @@
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Mesh data utilities."""
import matplotlib.pyplot as plt
import modules
from mpl_toolkits import mplot3d # pylint: disable=unused-import
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import networkx as nx
import numpy as np
import six
from six.moves import range
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
def random_shift(vertices, shift_factor=0.25):
"""Apply random shift to vertices."""
max_shift_pos = tf.cast(255 - tf.reduce_max(vertices, axis=0), tf.float32)
max_shift_pos = tf.maximum(max_shift_pos, 1e-9)
max_shift_neg = tf.cast(tf.reduce_min(vertices, axis=0), tf.float32)
max_shift_neg = tf.maximum(max_shift_neg, 1e-9)
shift = tfd.TruncatedNormal(
tf.zeros([1, 3]), shift_factor*255, -max_shift_neg,
max_shift_pos).sample()
shift = tf.cast(shift, tf.int32)
vertices += shift
return vertices
def make_vertex_model_dataset(ds, apply_random_shift=False):
"""Prepare dataset for vertex model training."""
def _vertex_model_map_fn(example):
vertices = example['vertices']
# Randomly shift vertices
if apply_random_shift:
vertices = random_shift(vertices)
# Re-order vertex coordinates as (z, y, x).
vertices_permuted = tf.stack(
[vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)
# Flatten quantized vertices, reindex starting from 1, and pad with a
# zero stopping token.
vertices_flat = tf.reshape(vertices_permuted, [-1])
example['vertices_flat'] = tf.pad(vertices_flat + 1, [[0, 1]])
# Create mask to indicate valid tokens after padding and batching.
example['vertices_flat_mask'] = tf.ones_like(
example['vertices_flat'], dtype=tf.float32)
return example
return ds.map(_vertex_model_map_fn)
def make_face_model_dataset(
ds, apply_random_shift=False, shuffle_vertices=True, quantization_bits=8):
"""Prepare dataset for face model training."""
def _face_model_map_fn(example):
vertices = example['vertices']
# Randomly shift vertices
if apply_random_shift:
vertices = random_shift(vertices)
example['num_vertices'] = tf.shape(vertices)[0]
# Optionally shuffle vertices and re-order faces to match
if shuffle_vertices:
permutation = tf.random_shuffle(tf.range(example['num_vertices']))
vertices = tf.gather(vertices, permutation)
face_permutation = tf.concat(
[tf.constant([0, 1], dtype=tf.int32), tf.argsort(permutation) + 2],
axis=0)
example['faces'] = tf.cast(
tf.gather(face_permutation, example['faces']), tf.int64)
# Vertices are quantized. So convert to floats for input to face model
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
example['vertices_mask'] = tf.ones_like(
example['vertices'][Ellipsis, 0], dtype=tf.float32)
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
return example
return ds.map(_face_model_map_fn)
def read_obj(obj_path):
"""Read vertices and faces from .obj file."""
vertex_list = []
flat_vertices_list = []
flat_vertices_indices = {}
flat_triangles = []
with open(obj_path) as obj_file:
for line in obj_file:
tokens = line.split()
if not tokens:
continue
line_type = tokens[0]
# We skip lines not starting with v or f.
if line_type == 'v':
vertex_list.append([float(x) for x in tokens[1:]])
elif line_type == 'f':
triangle = []
for i in range(len(tokens) - 1):
vertex_name = tokens[i + 1]
if vertex_name in flat_vertices_indices:
triangle.append(flat_vertices_indices[vertex_name])
continue
flat_vertex = []
for index in six.ensure_str(vertex_name).split('/'):
if not index:
continue
# obj triangle indices are 1 indexed, so subtract 1 here.
flat_vertex += vertex_list[int(index) - 1]
flat_vertex_index = len(flat_vertices_list)
flat_vertices_list.append(flat_vertex)
flat_vertices_indices[vertex_name] = flat_vertex_index
triangle.append(flat_vertex_index)
flat_triangles.append(triangle)
return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
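def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
  """Write vertices and faces to obj."""
  if transpose:
    vertices = vertices[:, [1, 2, 0]]
  # Scale out of place so that the caller's array is not modified.
  vertices = vertices * scale
  # Obj face indices are 1-indexed, so shift zero-indexed faces by one. Take
  # the minimum over all faces, not just the lexicographically smallest one.
  f_add = 0
  if faces is not None and len(faces) and min(min(f) for f in faces) == 0:
    f_add = 1
  with open(file_path, 'w') as f:
    for v in vertices:
      f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
    if faces is not None:
      for face in faces:
        line = 'f'
        for i in face:
          line += ' {}'.format(i + f_add)
        line += '\n'
        f.write(line)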
def quantize_verts(verts, n_bits=8):
  """Convert vertices in [-0.5, 0.5] to discrete values in [0, 2**n_bits - 1]."""
min_range = -0.5
max_range = 0.5
range_quantize = 2**n_bits - 1
verts_quantize = (verts - min_range) * range_quantize / (
max_range - min_range)
return verts_quantize.astype('int32')
def dequantize_verts(verts, n_bits=8, add_noise=False):
"""Convert quantized vertices to floats."""
min_range = -0.5
max_range = 0.5
range_quantize = 2**n_bits - 1
verts = verts.astype('float32')
verts = verts * (max_range - min_range) / range_quantize + min_range
if add_noise:
verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
return verts
def face_to_cycles(face):
"""Find cycles in face."""
g = nx.Graph()
for v in range(len(face) - 1):
g.add_edge(face[v], face[v + 1])
g.add_edge(face[-1], face[0])
return list(nx.cycle_basis(g))
def flatten_faces(faces):
"""Converts from list of faces to flat face array with stopping indices."""
if not faces:
return np.array([0])
else:
l = [f + [-1] for f in faces[:-1]]
l += [faces[-1] + [-2]]
return np.array([item for sublist in l for item in sublist]) + 2 # pylint: disable=g-complex-comprehension
def unflatten_faces(flat_faces):
"""Converts from flat face sequence to a list of separate faces."""
def group(seq):
g = []
for el in seq:
if el == 0 or el == -1:
yield g
g = []
else:
g.append(el - 1)
yield g
outputs = list(group(flat_faces - 1))[:-1]
# Remove empty faces
return [o for o in outputs if len(o) > 2]
def center_vertices(vertices):
"""Translate the vertices so that bounding box is centered at zero."""
vert_min = vertices.min(axis=0)
vert_max = vertices.max(axis=0)
vert_center = 0.5 * (vert_min + vert_max)
return vertices - vert_center
def normalize_vertices_scale(vertices):
"""Scale the vertices so that the long diagonal of the bounding box is one."""
vert_min = vertices.min(axis=0)
vert_max = vertices.max(axis=0)
extents = vert_max - vert_min
scale = np.sqrt(np.sum(extents**2))
return vertices / scale
def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8):
"""Quantize vertices, remove resulting duplicates and reindex faces."""
vertices = quantize_verts(vertices, quantization_bits)
vertices, inv = np.unique(vertices, axis=0, return_inverse=True)
# Sort vertices by z then y then x.
sort_inds = np.lexsort(vertices.T)
vertices = vertices[sort_inds]
# Re-index faces and tris to re-ordered vertices.
faces = [np.argsort(sort_inds)[inv[f]] for f in faces]
if tris is not None:
tris = np.array([np.argsort(sort_inds)[inv[t]] for t in tris])
# Merging duplicate vertices and re-indexing the faces causes some faces to
# contain loops (e.g [2, 3, 5, 2, 4]). Split these faces into distinct
# sub-faces.
sub_faces = []
for f in faces:
cliques = face_to_cycles(f)
for c in cliques:
c_length = len(c)
# Only append faces with more than two verts.
if c_length > 2:
d = np.argmin(c)
        # Cyclically permute faces so that the first index is the smallest.
sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
faces = sub_faces
if tris is not None:
tris = np.array([v for v in tris if len(set(v)) == len(v)])
# Sort faces by lowest vertex indices. If two faces have the same lowest
# index then sort by next lowest and so on.
faces.sort(key=lambda f: tuple(sorted(f)))
if tris is not None:
tris = tris.tolist()
tris.sort(key=lambda f: tuple(sorted(f)))
tris = np.array(tris)
# After removing degenerate faces some vertices are now unreferenced.
# Remove these.
num_verts = vertices.shape[0]
vert_connected = np.equal(
np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1)
vertices = vertices[vert_connected]
# Re-index faces and tris to re-ordered vertices.
vert_indices = (
np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int')))
faces = [vert_indices[f].tolist() for f in faces]
if tris is not None:
tris = np.array([vert_indices[t].tolist() for t in tris])
return vertices, faces, tris
def load_process_mesh(mesh_obj_path, quantization_bits=8):
"""Load obj file and process."""
# Load mesh
vertices, faces = read_obj(mesh_obj_path)
# Transpose so that z-axis is vertical.
vertices = vertices[:, [2, 0, 1]]
# Translate the vertices so that bounding box is centered at zero.
vertices = center_vertices(vertices)
# Scale the vertices so that the long diagonal of the bounding box is equal
# to one.
vertices = normalize_vertices_scale(vertices)
# Quantize and sort vertices, remove resulting duplicates, sort and reindex
# faces.
vertices, faces, _ = quantize_process_mesh(
vertices, faces, quantization_bits=quantization_bits)
# Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
faces = flatten_faces(faces)
# Discard degenerate meshes without faces.
return {
'vertices': vertices,
'faces': faces,
}
def plot_meshes(mesh_list,
ax_lims=0.3,
fig_size=4,
el=30,
rot_start=120,
vert_size=10,
vert_alpha=0.75,
n_cols=4):
"""Plots mesh data using matplotlib."""
n_plot = len(mesh_list)
n_cols = np.minimum(n_plot, n_cols)
n_rows = np.ceil(n_plot / n_cols).astype('int')
fig = plt.figure(figsize=(fig_size * n_cols, fig_size * n_rows))
for p_inc, mesh in enumerate(mesh_list):
for key in [
'vertices', 'faces', 'vertices_conditional', 'pointcloud', 'class_name'
]:
if key not in list(mesh.keys()):
mesh[key] = None
ax = fig.add_subplot(n_rows, n_cols, p_inc + 1, projection='3d')
if mesh['faces'] is not None:
if mesh['vertices_conditional'] is not None:
face_verts = np.concatenate(
[mesh['vertices_conditional'], mesh['vertices']], axis=0)
else:
face_verts = mesh['vertices']
collection = []
for f in mesh['faces']:
collection.append(face_verts[f])
plt_mesh = Poly3DCollection(collection)
plt_mesh.set_edgecolor((0., 0., 0., 0.3))
plt_mesh.set_facecolor((1, 0, 0, 0.2))
ax.add_collection3d(plt_mesh)
if mesh['vertices'] is not None:
ax.scatter3D(
mesh['vertices'][:, 0],
mesh['vertices'][:, 1],
mesh['vertices'][:, 2],
lw=0.,
s=vert_size,
c='g',
alpha=vert_alpha)
if mesh['vertices_conditional'] is not None:
ax.scatter3D(
mesh['vertices_conditional'][:, 0],
mesh['vertices_conditional'][:, 1],
mesh['vertices_conditional'][:, 2],
lw=0.,
s=vert_size,
c='b',
alpha=vert_alpha)
if mesh['pointcloud'] is not None:
ax.scatter3D(
mesh['pointcloud'][:, 0],
mesh['pointcloud'][:, 1],
mesh['pointcloud'][:, 2],
lw=0.,
s=2.5 * vert_size,
c='b',
alpha=1.)
ax.set_xlim(-ax_lims, ax_lims)
ax.set_ylim(-ax_lims, ax_lims)
ax.set_zlim(-ax_lims, ax_lims)
ax.view_init(el, rot_start)
display_string = ''
if mesh['faces'] is not None:
display_string += 'Num. faces: {}\n'.format(len(collection))
if mesh['vertices'] is not None:
num_verts = mesh['vertices'].shape[0]
if mesh['vertices_conditional'] is not None:
num_verts += mesh['vertices_conditional'].shape[0]
display_string += 'Num. verts: {}\n'.format(num_verts)
if mesh['class_name'] is not None:
      display_string += 'Synset: {}\n'.format(mesh['class_name'])
if mesh['pointcloud'] is not None:
display_string += 'Num. pointcloud: {}\n'.format(
mesh['pointcloud'].shape[0])
ax.text2D(0.05, 0.8, display_string, transform=ax.transAxes)
plt.subplots_adjust(
left=0., right=1., bottom=0., top=1., wspace=0.025, hspace=0.025)
plt.show()
Binary file not shown. Size: 281 KiB

+66
@@ -0,0 +1,66 @@
v 0.000000 -1.000000 -1.000000
v 0.195090 -1.000000 -0.980785
v 0.382683 -1.000000 -0.923880
v 0.555570 -1.000000 -0.831470
v 0.707107 -1.000000 -0.707107
v 0.831470 -1.000000 -0.555570
v 0.923880 -1.000000 -0.382683
v 0.980785 -1.000000 -0.195090
v 1.000000 -1.000000 -0.000000
v 0.980785 -1.000000 0.195090
v 0.923880 -1.000000 0.382683
v 0.831470 -1.000000 0.555570
v 0.707107 -1.000000 0.707107
v 0.555570 -1.000000 0.831470
v 0.382683 -1.000000 0.923880
v 0.195090 -1.000000 0.980785
v -0.000000 -1.000000 1.000000
v -0.195091 -1.000000 0.980785
v -0.382684 -1.000000 0.923879
v -0.555571 -1.000000 0.831469
v -0.707107 -1.000000 0.707106
v -0.831470 -1.000000 0.555570
v -0.923880 -1.000000 0.382683
v 0.000000 1.000000 0.000000
v -0.980785 -1.000000 0.195089
v -1.000000 -1.000000 -0.000001
v -0.980785 -1.000000 -0.195091
v -0.923879 -1.000000 -0.382684
v -0.831469 -1.000000 -0.555571
v -0.707106 -1.000000 -0.707108
v -0.555569 -1.000000 -0.831470
v -0.382682 -1.000000 -0.923880
v -0.195089 -1.000000 -0.980786
f 1 24 2
f 2 24 3
f 3 24 4
f 4 24 5
f 5 24 6
f 6 24 7
f 7 24 8
f 8 24 9
f 9 24 10
f 10 24 11
f 11 24 12
f 12 24 13
f 13 24 14
f 14 24 15
f 15 24 16
f 16 24 17
f 17 24 18
f 18 24 19
f 19 24 20
f 20 24 21
f 21 24 22
f 22 24 23
f 23 24 25
f 25 24 26
f 26 24 27
f 27 24 28
f 28 24 29
f 29 24 30
f 30 24 31
f 31 24 32
f 32 24 33
f 33 24 1
f 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 25 26 27 28 29 30 31 32 33
+14
@@ -0,0 +1,14 @@
v -1.000000 -1.000000 1.000000
v -1.000000 1.000000 1.000000
v -1.000000 -1.000000 -1.000000
v -1.000000 1.000000 -1.000000
v 1.000000 -1.000000 1.000000
v 1.000000 1.000000 1.000000
v 1.000000 -1.000000 -1.000000
v 1.000000 1.000000 -1.000000
f 1 2 4 3
f 3 4 8 7
f 7 8 6 5
f 5 6 2 1
f 3 7 5 1
f 8 4 2 6
+98
@@ -0,0 +1,98 @@
v 0.000000 -1.000000 -1.000000
v 0.000000 1.000000 -1.000000
v 0.195090 -1.000000 -0.980785
v 0.195090 1.000000 -0.980785
v 0.382683 -1.000000 -0.923880
v 0.382683 1.000000 -0.923880
v 0.555570 -1.000000 -0.831470
v 0.555570 1.000000 -0.831470
v 0.707107 -1.000000 -0.707107
v 0.707107 1.000000 -0.707107
v 0.831470 -1.000000 -0.555570
v 0.831470 1.000000 -0.555570
v 0.923880 -1.000000 -0.382683
v 0.923880 1.000000 -0.382683
v 0.980785 -1.000000 -0.195090
v 0.980785 1.000000 -0.195090
v 1.000000 -1.000000 -0.000000
v 1.000000 1.000000 -0.000000
v 0.980785 -1.000000 0.195090
v 0.980785 1.000000 0.195090
v 0.923880 -1.000000 0.382683
v 0.923880 1.000000 0.382683
v 0.831470 -1.000000 0.555570
v 0.831470 1.000000 0.555570
v 0.707107 -1.000000 0.707107
v 0.707107 1.000000 0.707107
v 0.555570 -1.000000 0.831470
v 0.555570 1.000000 0.831470
v 0.382683 -1.000000 0.923880
v 0.382683 1.000000 0.923880
v 0.195090 -1.000000 0.980785
v 0.195090 1.000000 0.980785
v -0.000000 -1.000000 1.000000
v -0.000000 1.000000 1.000000
v -0.195091 -1.000000 0.980785
v -0.195091 1.000000 0.980785
v -0.382684 -1.000000 0.923879
v -0.382684 1.000000 0.923879
v -0.555571 -1.000000 0.831469
v -0.555571 1.000000 0.831469
v -0.707107 -1.000000 0.707106
v -0.707107 1.000000 0.707106
v -0.831470 -1.000000 0.555570
v -0.831470 1.000000 0.555570
v -0.923880 -1.000000 0.382683
v -0.923880 1.000000 0.382683
v -0.980785 -1.000000 0.195089
v -0.980785 1.000000 0.195089
v -1.000000 -1.000000 -0.000001
v -1.000000 1.000000 -0.000001
v -0.980785 -1.000000 -0.195091
v -0.980785 1.000000 -0.195091
v -0.923879 -1.000000 -0.382684
v -0.923879 1.000000 -0.382684
v -0.831469 -1.000000 -0.555571
v -0.831469 1.000000 -0.555571
v -0.707106 -1.000000 -0.707108
v -0.707106 1.000000 -0.707108
v -0.555569 -1.000000 -0.831470
v -0.555569 1.000000 -0.831470
v -0.382682 -1.000000 -0.923880
v -0.382682 1.000000 -0.923880
v -0.195089 -1.000000 -0.980786
v -0.195089 1.000000 -0.980786
f 1 2 4 3
f 3 4 6 5
f 5 6 8 7
f 7 8 10 9
f 9 10 12 11
f 11 12 14 13
f 13 14 16 15
f 15 16 18 17
f 17 18 20 19
f 19 20 22 21
f 21 22 24 23
f 23 24 26 25
f 25 26 28 27
f 27 28 30 29
f 29 30 32 31
f 31 32 34 33
f 33 34 36 35
f 35 36 38 37
f 37 38 40 39
f 39 40 42 41
f 41 42 44 43
f 43 44 46 45
f 45 46 48 47
f 47 48 50 49
f 49 50 52 51
f 51 52 54 53
f 53 54 56 55
f 55 56 58 57
f 57 58 60 59
f 59 60 62 61
f 4 2 64 62 60 58 56 54 52 50 48 46 44 42 40 38 36 34 32 30 28 26 24 22 20 18 16 14 12 10 8 6
f 61 62 64 63
f 63 64 2 1
f 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39 41 43 45 47 49 51 53 55 57 59 61 63
+122
@@ -0,0 +1,122 @@
v 0.000000 -1.000000 0.000000
v 0.723607 -0.447220 0.525725
v -0.276388 -0.447220 0.850649
v -0.894426 -0.447216 0.000000
v -0.276388 -0.447220 -0.850649
v 0.723607 -0.447220 -0.525725
v 0.276388 0.447220 0.850649
v -0.723607 0.447220 0.525725
v -0.723607 0.447220 -0.525725
v 0.276388 0.447220 -0.850649
v 0.894426 0.447216 0.000000
v 0.000000 1.000000 0.000000
v -0.162456 -0.850654 0.499995
v 0.425323 -0.850654 0.309011
v 0.262869 -0.525738 0.809012
v 0.850648 -0.525736 0.000000
v 0.425323 -0.850654 -0.309011
v -0.525730 -0.850652 0.000000
v -0.688189 -0.525736 0.499997
v -0.162456 -0.850654 -0.499995
v -0.688189 -0.525736 -0.499997
v 0.262869 -0.525738 -0.809012
v 0.951058 0.000000 0.309013
v 0.951058 0.000000 -0.309013
v 0.000000 0.000000 1.000000
v 0.587786 0.000000 0.809017
v -0.951058 0.000000 0.309013
v -0.587786 0.000000 0.809017
v -0.587786 0.000000 -0.809017
v -0.951058 0.000000 -0.309013
v 0.587786 0.000000 -0.809017
v 0.000000 0.000000 -1.000000
v 0.688189 0.525736 0.499997
v -0.262869 0.525738 0.809012
v -0.850648 0.525736 0.000000
v -0.262869 0.525738 -0.809012
v 0.688189 0.525736 -0.499997
v 0.162456 0.850654 0.499995
v 0.525730 0.850652 0.000000
v -0.425323 0.850654 0.309011
v -0.425323 0.850654 -0.309011
v 0.162456 0.850654 -0.499995
f 1 14 13
f 2 14 16
f 1 13 18
f 1 18 20
f 1 20 17
f 2 16 23
f 3 15 25
f 4 19 27
f 5 21 29
f 6 22 31
f 2 23 26
f 3 25 28
f 4 27 30
f 5 29 32
f 6 31 24
f 7 33 38
f 8 34 40
f 9 35 41
f 10 36 42
f 11 37 39
f 39 42 12
f 39 37 42
f 37 10 42
f 42 41 12
f 42 36 41
f 36 9 41
f 41 40 12
f 41 35 40
f 35 8 40
f 40 38 12
f 40 34 38
f 34 7 38
f 38 39 12
f 38 33 39
f 33 11 39
f 24 37 11
f 24 31 37
f 31 10 37
f 32 36 10
f 32 29 36
f 29 9 36
f 30 35 9
f 30 27 35
f 27 8 35
f 28 34 8
f 28 25 34
f 25 7 34
f 26 33 7
f 26 23 33
f 23 11 33
f 31 32 10
f 31 22 32
f 22 5 32
f 29 30 9
f 29 21 30
f 21 4 30
f 27 28 8
f 27 19 28
f 19 3 28
f 25 26 7
f 25 15 26
f 15 2 26
f 23 24 11
f 23 16 24
f 16 6 24
f 17 22 6
f 17 20 22
f 20 5 22
f 20 21 5
f 20 18 21
f 18 4 21
f 18 19 4
f 18 13 19
f 13 3 19
f 16 17 6
f 16 14 17
f 14 1 17
f 13 15 3
f 13 14 15
f 14 2 15
+159
@@ -0,0 +1,159 @@
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the PolyGen open-source version."""
from modules import FaceModel
from modules import VertexModel
import numpy as np
import tensorflow as tf
_BATCH_SIZE = 4
_TRANSFORMER_CONFIG = {
'num_layers': 2,
'hidden_size': 64,
'fc_size': 256
}
_CLASS_CONDITIONAL = True
_NUM_CLASSES = 4
_NUM_INPUT_VERTS = 50
_NUM_PAD_VERTS = 10
_NUM_INPUT_FACE_INDICES = 200
_QUANTIZATION_BITS = 8
_VERTEX_MODEL_USE_DISCRETE_EMBEDDINGS = True
_FACE_MODEL_DECODER_CROSS_ATTENTION = True
_FACE_MODEL_DISCRETE_EMBEDDINGS = True
_MAX_SAMPLE_LENGTH_VERTS = 10
_MAX_SAMPLE_LENGTH_FACES = 10
def _get_vertex_model_batch():
"""Returns batch with placeholders for vertex model inputs."""
return {
'class_label': tf.range(_BATCH_SIZE),
'vertices_flat': tf.placeholder(
dtype=tf.int32, shape=[_BATCH_SIZE, None]),
}
def _get_face_model_batch():
"""Returns batch with placeholders for face model inputs."""
return {
'vertices': tf.placeholder(
dtype=tf.float32, shape=[_BATCH_SIZE, None, 3]),
'vertices_mask': tf.placeholder(
dtype=tf.float32, shape=[_BATCH_SIZE, None]),
'faces': tf.placeholder(
dtype=tf.int32, shape=[_BATCH_SIZE, None]),
}
class VertexModelTest(tf.test.TestCase):
def setUp(self):
"""Defines a vertex model."""
super(VertexModelTest, self).setUp()
self.model = VertexModel(
decoder_config=_TRANSFORMER_CONFIG,
class_conditional=_CLASS_CONDITIONAL,
num_classes=_NUM_CLASSES,
max_num_input_verts=_NUM_INPUT_VERTS,
quantization_bits=_QUANTIZATION_BITS,
use_discrete_embeddings=_VERTEX_MODEL_USE_DISCRETE_EMBEDDINGS)
def test_model_runs(self):
"""Tests if the model runs without crashing."""
batch = _get_vertex_model_batch()
pred_dist = self.model(batch, is_training=False)
logits = pred_dist.logits
with self.session() as sess:
sess.run(tf.global_variables_initializer())
vertices_flat = np.random.randint(
2**_QUANTIZATION_BITS + 1,
size=[_BATCH_SIZE, _NUM_INPUT_VERTS * 3 + 1])
sess.run(logits, {batch['vertices_flat']: vertices_flat})
def test_sample_outputs_range(self):
"""Tests if the model produces samples in the correct range."""
context = {'class_label': tf.zeros((_BATCH_SIZE,), dtype=tf.int32)}
sample_dict = self.model.sample(
_BATCH_SIZE, max_sample_length=_MAX_SAMPLE_LENGTH_VERTS,
context=context)
with self.session() as sess:
sess.run(tf.global_variables_initializer())
sample_dict_np = sess.run(sample_dict)
in_range = np.logical_and(
0 <= sample_dict_np['vertices'],
sample_dict_np['vertices'] <= 2**_QUANTIZATION_BITS).all()
self.assertTrue(in_range)
class FaceModelTest(tf.test.TestCase):
def setUp(self):
"""Defines a face model."""
super(FaceModelTest, self).setUp()
self.model = FaceModel(
encoder_config=_TRANSFORMER_CONFIG,
decoder_config=_TRANSFORMER_CONFIG,
class_conditional=False,
max_seq_length=_NUM_INPUT_FACE_INDICES,
decoder_cross_attention=_FACE_MODEL_DECODER_CROSS_ATTENTION,
use_discrete_vertex_embeddings=_FACE_MODEL_DISCRETE_EMBEDDINGS,
quantization_bits=_QUANTIZATION_BITS)
def test_model_runs(self):
"""Tests if the model runs without crashing."""
batch = _get_face_model_batch()
pred_dist = self.model(batch, is_training=False)
logits = pred_dist.logits
with self.session() as sess:
sess.run(tf.global_variables_initializer())
vertices = np.random.rand(_BATCH_SIZE, _NUM_INPUT_VERTS, 3) - 0.5
vertices_mask = np.ones([_BATCH_SIZE, _NUM_INPUT_VERTS])
faces = np.random.randint(
_NUM_INPUT_VERTS + 2, size=[_BATCH_SIZE, _NUM_INPUT_FACE_INDICES])
sess.run(
logits,
{batch['vertices']: vertices,
batch['vertices_mask']: vertices_mask,
batch['faces']: faces}
)
def test_sample_outputs_range(self):
"""Tests if the model produces samples in the correct range."""
context = _get_face_model_batch()
del context['faces']
sample_dict = self.model.sample(
context, max_sample_length=_MAX_SAMPLE_LENGTH_FACES)
with self.session() as sess:
sess.run(tf.global_variables_initializer())
# Pad the vertices in order to test that the face model only outputs
# vertex indices in the unpadded range
vertices = np.pad(
np.random.rand(_BATCH_SIZE, _NUM_INPUT_VERTS, 3) - 0.5,
[[0, 0], [0, _NUM_PAD_VERTS], [0, 0]], mode='constant')
vertices_mask = np.pad(
np.ones([_BATCH_SIZE, _NUM_INPUT_VERTS]),
[[0, 0], [0, _NUM_PAD_VERTS]], mode='constant')
sample_dict_np = sess.run(
sample_dict,
{context['vertices']: vertices,
context['vertices_mask']: vertices_mask})
in_range = np.logical_and(
0 <= sample_dict_np['faces'],
sample_dict_np['faces'] <= _NUM_INPUT_VERTS + 1).all()
self.assertTrue(in_range)
if __name__ == '__main__':
tf.test.main()
+1506
File diff suppressed because it is too large.
+20
@@ -0,0 +1,20 @@
#!/bin/sh
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python3 -m venv polygen
# `source` is a bash builtin; `.` is the POSIX equivalent and works under /bin/sh.
. polygen/bin/activate
pip3 install .
python3 model_test.py
deactivate
+294
@@ -0,0 +1,294 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lUh_eWpealmh"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kYd9gIfGJYZ8"
},
"source": [
"## Clone repo and import dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ux33ZDQ_tqUV"
},
"outputs": [],
"source": [
"!pip install tensorflow==1.15\n",
"!pip install dm-sonnet==1.36\n",
"!pip install tensor2tensor==1.14\n",
"\n",
"import time\n",
"import numpy as np\n",
"import tensorflow.compat.v1 as tf\n",
"tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n",
"import matplotlib.pyplot as plt\n",
"\n",
"!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n",
"%cd deepmind_research/polygen\n",
"import modules\n",
"import data_utils"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YUZNqHTVJbm3"
},
"source": [
"## Download pre-trained model weights from Google Cloud Storage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LpZjBUmq10gX"
},
"outputs": [],
"source": [
"!mkdir -p /tmp/vertex_model\n",
"!mkdir -p /tmp/face_model\n",
"!gsutil cp gs://deepmind-research-polygen/vertex_model.tar.gz /tmp/vertex_model/\n",
"!gsutil cp gs://deepmind-research-polygen/face_model.tar.gz /tmp/face_model/\n",
"!tar xvfz /tmp/vertex_model/vertex_model.tar.gz -C /tmp/vertex_model/\n",
"!tar xvfz /tmp/face_model/face_model.tar.gz -C /tmp/face_model/"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vXhMOoxaJ3Xb"
},
"source": [
"## Pre-trained model config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "7VGSSS9vJSn-"
},
"outputs": [],
"source": [
"vertex_module_config=dict(\n",
" decoder_config=dict(\n",
" hidden_size=512,\n",
" fc_size=2048,\n",
" num_heads=8,\n",
" layer_norm=True,\n",
" num_layers=24,\n",
" dropout_rate=0.4,\n",
" re_zero=True,\n",
" memory_efficient=True\n",
" ),\n",
" quantization_bits=8,\n",
" class_conditional=True,\n",
" max_num_input_verts=5000,\n",
" use_discrete_embeddings=True,\n",
" )\n",
"\n",
"face_module_config=dict(\n",
" encoder_config=dict(\n",
" hidden_size=512,\n",
" fc_size=2048,\n",
" num_heads=8,\n",
" layer_norm=True,\n",
" num_layers=10,\n",
" dropout_rate=0.2,\n",
" re_zero=True,\n",
" memory_efficient=True,\n",
" ),\n",
" decoder_config=dict(\n",
" hidden_size=512,\n",
" fc_size=2048,\n",
" num_heads=8,\n",
" layer_norm=True,\n",
" num_layers=14,\n",
" dropout_rate=0.2,\n",
" re_zero=True,\n",
" memory_efficient=True,\n",
" ),\n",
" class_conditional=False,\n",
" decoder_cross_attention=True,\n",
" use_discrete_vertex_embeddings=True,\n",
" max_seq_length=8000,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WNXf_XbuKW4S"
},
"source": [
"## Generate class-conditional samples\n",
"\n",
"Try varying the `class_id` parameter to generate meshes from different object categories. Good classes to try are tables (49), lamps (30), and cabinets (12).\n",
"\n",
"We can also specify the maximum number of vertices / face indices we want to see in the generated meshes using `max_num_vertices` and `max_num_face_indices`. The code will keep generating batches of samples until there are at least `num_samples_min` complete samples with the required number of vertices / faces.\n",
"\n",
"`top_p_vertex_model` and `top_p_face_model` control how varied the outputs are, with `1.` being the most varied, and `0.` the least varied. `0.9` is a good value for both the vertex and face models.\n",
"\n",
"Sampling should take around 2-5 minutes with a colab GPU using the default settings depending on the object class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kqKMbPJJu3lk"
},
"outputs": [],
"source": [
"class_id = '49) table' #@param ['0) airplane,aeroplane,plane','1) ashcan,trash can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin','2) bag,traveling bag,travelling bag,grip,suitcase','3) basket,handbasket','4) bathtub,bathing tub,bath,tub','5) bed','6) bench','7) birdhouse','8) bookshelf','9) bottle','10) bowl','11) bus,autobus,coach,charabanc,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehi','12) cabinet','13) camera,photographic camera','14) can,tin,tin can','15) cap','16) car,auto,automobile,machine,motorcar','17) cellular telephone,cellular phone,cellphone,cell,mobile phone','18) chair','19) clock','20) computer keyboard,keypad','21) dishwasher,dish washer,dishwashing machine','22) display,video display','23) earphone,earpiece,headphone,phone','24) faucet,spigot','25) file,file cabinet,filing cabinet','26) guitar','27) helmet','28) jar','29) knife','30) lamp','31) laptop,laptop computer','32) loudspeaker,speaker,speaker unit,loudspeaker system,speaker system','33) mailbox,letter box','34) microphone,mike','35) microwave,microwave oven','36) motorcycle,bike','37) mug','38) piano,pianoforte,forte-piano','39) pillow','40) pistol,handgun,side arm,shooting iron','41) pot,flowerpot','42) printer,printing machine','43) remote control,remote','44) rifle','45) rocket,projectile','46) skateboard','47) sofa,couch,lounge','48) stove','49) table','50) telephone,phone,telephone set','51) tower','52) train,railroad train','53) vessel,watercraft','54) washer,automatic washer,washing machine']\n",
"num_samples_min = 1 #@param\n",
"num_samples_batch = 8 #@param\n",
"max_num_vertices = 400 #@param\n",
"max_num_face_indices = 2000 #@param\n",
"top_p_vertex_model = 0.9 #@param\n",
"top_p_face_model = 0.9 #@param\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"# Build models\n",
"vertex_model = modules.VertexModel(**vertex_module_config)\n",
"face_model = modules.FaceModel(**face_module_config)\n",
"\n",
"# Tile out class label to every element in batch\n",
"class_id = int(class_id.split(')')[0])\n",
"vertex_model_context = {'class_label': tf.fill([num_samples_batch,], class_id)}\n",
"vertex_samples = vertex_model.sample(\n",
"    num_samples_batch, context=vertex_model_context,\n",
"    max_sample_length=max_num_vertices, top_p=top_p_vertex_model,\n",
"    recenter_verts=True, only_return_complete=True)\n",
"vertex_model_saver = tf.train.Saver(var_list=vertex_model.variables)\n",
"\n",
"# The face model generates samples conditioned on a context, which here is\n",
"# the vertex model samples\n",
"face_samples = face_model.sample(\n",
"    vertex_samples, max_sample_length=max_num_face_indices,\n",
"    top_p=top_p_face_model, only_return_complete=True)\n",
"face_model_saver = tf.train.Saver(var_list=face_model.variables)\n",
"\n",
"# Start sampling\n",
"start = time.time()\n",
"print('Generating samples...')\n",
"with tf.Session() as sess:\n",
"  vertex_model_saver.restore(sess, '/tmp/vertex_model/model')\n",
"  face_model_saver.restore(sess, '/tmp/face_model/model')\n",
"  mesh_list = []\n",
"  num_samples_complete = 0\n",
"  while num_samples_complete \u003c num_samples_min:\n",
"    v_samples_np = sess.run(vertex_samples)\n",
"    if v_samples_np['completed'].size == 0:\n",
"      print('No vertex samples completed in this batch. Try increasing ' +\n",
"            'max_num_vertices.')\n",
"      continue\n",
"    f_samples_np = sess.run(\n",
"        face_samples,\n",
"        {vertex_samples[k]: v_samples_np[k] for k in vertex_samples.keys()})\n",
"    v_samples_np = f_samples_np['context']\n",
"    num_samples_complete_batch = f_samples_np['completed'].sum()\n",
"    num_samples_complete += num_samples_complete_batch\n",
"    print('Num. samples complete: {}'.format(num_samples_complete))\n",
"    for k in range(num_samples_complete_batch):\n",
"      verts = v_samples_np['vertices'][k][:v_samples_np['num_vertices'][k]]\n",
"      faces = data_utils.unflatten_faces(\n",
"          f_samples_np['faces'][k][:f_samples_np['num_face_indices'][k]])\n",
"      mesh_list.append({'vertices': verts, 'faces': faces})\n",
"end = time.time()\n",
"print('sampling time: {}'.format(end - start))\n",
"\n",
"data_utils.plot_meshes(mesh_list, ax_lims=0.4) "
]
},
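{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"#### Aside: what `top_p` does\n",
"\n",
"The cell below is a minimal NumPy sketch of nucleus (top-p) filtering, included only to illustrate why `top_p` values near `1.` give the most varied samples and values near `0.` the least. It is not the implementation used in `modules.py`; the function name `top_p_filter_sketch` and the toy probabilities are assumptions made for this example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def top_p_filter_sketch(probs, top_p):\n",
"  \"\"\"Illustrative only: keeps the most probable tokens and renormalizes.\"\"\"\n",
"  order = np.argsort(-probs)  # Token indices, most probable first\n",
"  sorted_probs = probs[order]\n",
"  # Probability mass that precedes each token in the sorted order\n",
"  mass_before = np.cumsum(sorted_probs) - sorted_probs\n",
"  keep_sorted = mass_before \u003c top_p\n",
"  keep_sorted[0] = True  # Always keep the single most probable token\n",
"  filtered = np.zeros_like(probs)\n",
"  filtered[order[keep_sorted]] = sorted_probs[keep_sorted]\n",
"  return filtered / filtered.sum()\n",
"\n",
"\n",
"toy_probs = np.array([0.5, 0.3, 0.15, 0.05])  # Hypothetical distribution\n",
"for p in [1.0, 0.9, 0.0]:\n",
"  print('top_p={}: {}'.format(p, top_p_filter_sketch(toy_probs, p)))"
]
},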
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OOQV6pMvSymz"
},
"source": [
"## Export meshes as `.obj` files\n",
"Pick a `mesh_id` (starting at 0) corresponding to the samples generated above; it must be smaller than the number of meshes in `mesh_list`. Refresh the colab file browser to find an `.obj` file with the mesh data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fO0Klbq2Sx0m"
},
"outputs": [],
"source": [
"mesh_id = 4 #@param\n",
"data_utils.write_obj(\n",
"    mesh_list[mesh_id]['vertices'], mesh_list[mesh_id]['faces'],\n",
"    'mesh-{}.obj'.format(mesh_id))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "sampling-pretrained.ipynb",
"provenance": [
{
"file_id": "1yj-oHYqCnwYVGSM22coa68NqnSoNdyVr",
"timestamp": 1591622616999
},
{
"file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
"timestamp": 1591264007511
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
+40
@@ -0,0 +1,40 @@
# Copyright 2020 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup for pip package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from setuptools import find_packages
from setuptools import setup
REQUIRED_PACKAGES = ['numpy', 'dm-sonnet==1.36', 'tensorflow==1.14',
                     'tensor2tensor==1.15', 'networkx', 'matplotlib', 'six']
setup(
    name='polygen',
    version='0.1',
    description='A library for PolyGen: An Autoregressive Generative Model of 3D Meshes.',
    url='https://github.com/deepmind/deepmind-research/polygen',
    author='DeepMind',
    author_email='no-reply@google.com',
    # Contained modules and scripts.
    packages=find_packages(),
    install_requires=REQUIRED_PACKAGES,
    platforms=['any'],
    license='Apache 2.0',
)
+328
@@ -0,0 +1,328 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "szgmaK1HajOc"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_dv0afOrKheU"
},
"source": [
"## Clone repo and import dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ux33ZDQ_tqUV"
},
"outputs": [],
"source": [
"!pip install tensorflow==1.15\n",
"!pip install dm-sonnet==1.36\n",
"!pip install tensor2tensor==1.14\n",
"\n",
"import os\n",
"import numpy as np\n",
"import tensorflow.compat.v1 as tf\n",
"tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n",
"import matplotlib.pyplot as plt\n",
"\n",
"!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n",
"%cd deepmind_research/polygen\n",
"import modules\n",
"import data_utils"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U3GDZhJ5wGOf"
},
"source": [
"## Prepare a synthetic dataset\n",
"We prepare a dataset of meshes using four simple geometric primitives.\n",
"\n",
"The important function here is `data_utils.load_process_mesh`, which loads the raw `.obj` file, normalizes and centers the meshes, and applies quantization to the vertex positions. The mesh faces are flattened and treated as a long sequence, with a new-face token (`=1`) separating the faces. For each of the four synthetic meshes, we associate a unique class label, so we can train class-conditional models.\n",
"\n",
"After processing the raw mesh data into numpy arrays, we create a `tf.data.Dataset` that we can use to feed data to our models. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "3QAqwyZjtOdC"
},
"outputs": [],
"source": [
"# Prepare synthetic dataset\n",
"ex_list = []\n",
"for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):\n",
"  mesh_dict = data_utils.load_process_mesh(\n",
"      os.path.join('meshes', '{}.obj'.format(mesh)))\n",
"  mesh_dict['class_label'] = k\n",
"  ex_list.append(mesh_dict)\n",
"synthetic_dataset = tf.data.Dataset.from_generator(\n",
"    lambda: ex_list,\n",
"    output_types={\n",
"        'vertices': tf.int32, 'faces': tf.int32, 'class_label': tf.int32},\n",
"    output_shapes={\n",
"        'vertices': tf.TensorShape([None, 3]), 'faces': tf.TensorShape([None]),\n",
"        'class_label': tf.TensorShape(())}\n",
"    )\n",
"ex = synthetic_dataset.make_one_shot_iterator().get_next()\n",
"\n",
"# Inspect the first mesh\n",
"with tf.Session() as sess:\n",
"  ex_np = sess.run(ex)\n",
"print(ex_np)\n",
"\n",
"# Plot the meshes\n",
"mesh_list = []\n",
"with tf.Session() as sess:\n",
"  for i in range(4):\n",
"    ex_np = sess.run(ex)\n",
"    mesh_list.append(\n",
"        {'vertices': data_utils.dequantize_verts(ex_np['vertices']),\n",
"         'faces': data_utils.unflatten_faces(ex_np['faces'])})\n",
"data_utils.plot_meshes(mesh_list, ax_lims=0.4)"
]
},
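{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"#### Aside: quantization and flattened faces\n",
"\n",
"The cell below is a rough NumPy sketch of the two preprocessing ideas described above: quantizing vertex positions, and flattening faces into one sequence with a new-face token (`=1`) between faces. The helper names (`quantize_verts_sketch`, `dequantize_verts_sketch`, `flatten_faces_sketch`), the `[-0.5, 0.5]` coordinate range, the stopping token `0`, and the vertex index offset of 2 are assumptions made for illustration; `data_utils.py` holds the processing that is actually used."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def quantize_verts_sketch(verts, n_bits=8):\n",
"  \"\"\"Maps floats in [-0.5, 0.5] (assumed range) to ints in [0, 2**n_bits - 1].\"\"\"\n",
"  levels = 2 ** n_bits - 1\n",
"  return np.round((verts + 0.5) * levels).astype(np.int32)\n",
"\n",
"\n",
"def dequantize_verts_sketch(verts_quantized, n_bits=8):\n",
"  \"\"\"Approximate inverse of quantize_verts_sketch.\"\"\"\n",
"  levels = 2 ** n_bits - 1\n",
"  return verts_quantized.astype(np.float32) / levels - 0.5\n",
"\n",
"\n",
"def flatten_faces_sketch(faces, new_face_token=1, stop_token=0, offset=2):\n",
"  \"\"\"Flattens a list of faces; token values and offset are assumptions.\"\"\"\n",
"  flat = []\n",
"  for face in faces:\n",
"    if flat:  # Separate consecutive faces with the new-face token\n",
"      flat.append(new_face_token)\n",
"    # Shift vertex indices so they do not collide with the special tokens\n",
"    flat.extend(v + offset for v in face)\n",
"  flat.append(stop_token)\n",
"  return np.array(flat, dtype=np.int32)\n",
"\n",
"\n",
"toy_verts = np.array([[-0.5, 0.0, 0.5], [0.25, -0.25, 0.1]])\n",
"toy_quantized = quantize_verts_sketch(toy_verts)\n",
"print(toy_quantized)\n",
"print('Max round-trip error: {}'.format(\n",
"    np.abs(dequantize_verts_sketch(toy_quantized) - toy_verts).max()))\n",
"print(flatten_faces_sketch([[0, 1, 2], [0, 2, 3]]))"
]
},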
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9G2FCQQyyTXw"
},
"source": [
"## Vertex model\n",
"\n",
"#### Prepare the dataset for vertex model training\n",
"We need to perform some additional processing to make the dataset ready for vertex model training. In particular, `data_utils.make_vertex_model_dataset` flattens the `[V, 3]` vertex arrays, ordering by `Z-\u003eY-\u003eX` coordinates. It also creates masks, which are used to mask padded elements in data batches. Random shifts can optionally be applied to make the modelling task more challenging; they are disabled here (`apply_random_shift=False`).\n",
"\n",
"#### Create a vertex model\n",
"`modules.VertexModel` is a Sonnet module. Calling the module on a batch of data produces outputs which are the sequential predictions for each vertex coordinate. The basis of the vertex model is a Transformer decoder, and we specify its parameters in `decoder_config`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "o2KCoDeeFP8C"
},
"outputs": [],
"source": [
"# Prepare the dataset for vertex model training\n",
"vertex_model_dataset = data_utils.make_vertex_model_dataset(\n",
"    synthetic_dataset, apply_random_shift=False)\n",
"vertex_model_dataset = vertex_model_dataset.repeat()\n",
"vertex_model_dataset = vertex_model_dataset.padded_batch(\n",
"    4, padded_shapes=vertex_model_dataset.output_shapes)\n",
"vertex_model_dataset = vertex_model_dataset.prefetch(1)\n",
"vertex_model_batch = vertex_model_dataset.make_one_shot_iterator().get_next()\n",
"\n",
"# Create vertex model\n",
"vertex_model = modules.VertexModel(\n",
"    decoder_config={\n",
"        'hidden_size': 128,\n",
"        'fc_size': 512,\n",
"        'num_layers': 3,\n",
"        'dropout_rate': 0.\n",
"    },\n",
"    class_conditional=True,\n",
"    num_classes=4,\n",
"    max_num_input_verts=250,\n",
"    quantization_bits=8,\n",
")\n",
"vertex_model_pred_dist = vertex_model(vertex_model_batch)\n",
"vertex_model_loss = -tf.reduce_sum(\n",
"    vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat']) *\n",
"    vertex_model_batch['vertices_flat_mask'])\n",
"vertex_samples = vertex_model.sample(\n",
"    4, context=vertex_model_batch, max_sample_length=200, top_p=0.95,\n",
"    recenter_verts=False, only_return_complete=False)\n",
"\n",
"print(vertex_model_batch)\n",
"print(vertex_model_pred_dist)\n",
"print(vertex_samples)"
]
},
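{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"#### Aside: vertex ordering and masked loss\n",
"\n",
"The cell below is a small NumPy sketch of two ideas used above: ordering vertices by `Z-\u003eY-\u003eX` before flattening, and multiplying per-element log-probabilities by a mask so that padded positions do not contribute to the loss. The toy arrays are made up, and emitting each vertex as `(z, y, x)` is one plausible reading of the ordering, stated here as an assumption; it is not a copy of `data_utils.make_vertex_model_dataset`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Toy quantized vertices with columns (x, y, z)\n",
"toy_verts = np.array([[3, 1, 0], [1, 2, 0], [2, 0, 1], [0, 1, 0]])\n",
"\n",
"# np.lexsort treats the LAST key as the primary key, so this sorts by z,\n",
"# then by y, then by x.\n",
"order = np.lexsort((toy_verts[:, 0], toy_verts[:, 1], toy_verts[:, 2]))\n",
"toy_verts_sorted = toy_verts[order]\n",
"\n",
"# Assumption: each vertex is emitted as (z, y, x) and the result is flattened\n",
"toy_verts_flat = toy_verts_sorted[:, ::-1].reshape(-1)\n",
"print(toy_verts_sorted)\n",
"print(toy_verts_flat)\n",
"\n",
"# Masked negative log-likelihood: a batch of two sequences padded to length 3,\n",
"# where the second sequence has only one real element.\n",
"toy_log_probs = np.log(np.array([[0.5, 0.25, 0.125], [0.5, 0.5, 0.5]]))\n",
"toy_mask = np.array([[1., 1., 1.], [1., 0., 0.]])\n",
"toy_loss = -np.sum(toy_log_probs * toy_mask)  # Padded entries contribute 0\n",
"print(toy_loss)"
]
},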
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-9RNYr5x1jov"
},
"source": [
"## Face model\n",
"\n",
"#### Prepare the dataset for face model training\n",
"We need to perform some additional processing to make the dataset ready for face model training. In particular, `data_utils.make_face_model_dataset` prepares the vertices and flattened face sequences as inputs for the face model. It also creates masks, which are used to mask padded elements in data batches. Random shifts can optionally be applied to make the modelling task more challenging; they are disabled here (`apply_random_shift=False`).\n",
"\n",
"#### Create a face model\n",
"`modules.FaceModel` is a Sonnet module. Calling the module on a batch of data produces outputs which are the sequential predictions for each face index. The face model is built from a Transformer encoder and a Transformer decoder, and we specify their parameters in `encoder_config` and `decoder_config`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "a2yO6dOGzn8c"
},
"outputs": [],
"source": [
"face_model_dataset = data_utils.make_face_model_dataset(\n",
"    synthetic_dataset, apply_random_shift=False)\n",
"face_model_dataset = face_model_dataset.repeat()\n",
"face_model_dataset = face_model_dataset.padded_batch(\n",
"    4, padded_shapes=face_model_dataset.output_shapes)\n",
"face_model_dataset = face_model_dataset.prefetch(1)\n",
"face_model_batch = face_model_dataset.make_one_shot_iterator().get_next()\n",
"\n",
"# Create face model\n",
"face_model = modules.FaceModel(\n",
"    encoder_config={\n",
"        'hidden_size': 128,\n",
"        'fc_size': 512,\n",
"        'num_layers': 3,\n",
"        'dropout_rate': 0.\n",
"    },\n",
"    decoder_config={\n",
"        'hidden_size': 128,\n",
"        'fc_size': 512,\n",
"        'num_layers': 3,\n",
"        'dropout_rate': 0.\n",
"    },\n",
"    class_conditional=False,\n",
"    max_seq_length=500,\n",
"    quantization_bits=8,\n",
"    decoder_cross_attention=True,\n",
"    use_discrete_vertex_embeddings=True,\n",
")\n",
"face_model_pred_dist = face_model(face_model_batch)\n",
"face_model_loss = -tf.reduce_sum(\n",
"    face_model_pred_dist.log_prob(face_model_batch['faces']) *\n",
"    face_model_batch['faces_mask'])\n",
"face_samples = face_model.sample(\n",
"    context=vertex_samples, max_sample_length=500, top_p=0.95,\n",
"    only_return_complete=False)\n",
"print(face_model_batch)\n",
"print(face_model_pred_dist)\n",
"print(face_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hL7yloXB1pUb"
},
"source": [
"## Train on the synthetic data\n",
"\n",
"Now that we've created vertex and face models and their respective data loaders, we can train them and look at some outputs. While we train the models together here, they can be trained separately and recombined later if required."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hjrbofa8zqQt"
},
"outputs": [],
"source": [
"# Optimization settings\n",
"learning_rate = 5e-4\n",
"training_steps = 500\n",
"check_step = 5\n",
"\n",
"# Create an optimizer and minimize the summed negative log probability of the\n",
"# mesh sequences\n",
"optimizer = tf.train.AdamOptimizer(learning_rate)\n",
"vertex_model_optim_op = optimizer.minimize(vertex_model_loss)\n",
"face_model_optim_op = optimizer.minimize(face_model_loss)\n",
"\n",
"# Training loop\n",
"with tf.Session() as sess:\n",
"  sess.run(tf.global_variables_initializer())\n",
"  for n in range(training_steps):\n",
"    if n % check_step == 0:\n",
"      v_loss, f_loss = sess.run((vertex_model_loss, face_model_loss))\n",
"      print('Step {}'.format(n))\n",
"      print('Loss (vertices) {}'.format(v_loss))\n",
"      print('Loss (faces) {}'.format(f_loss))\n",
"      v_samples_np, f_samples_np, b_np = sess.run(\n",
"          (vertex_samples, face_samples, vertex_model_batch))\n",
"      mesh_list = []\n",
"      # Use a separate index so the training step counter `n` is not shadowed\n",
"      for k in range(4):\n",
"        mesh_list.append(\n",
"            {\n",
"                'vertices': v_samples_np['vertices'][k][:v_samples_np['num_vertices'][k]],\n",
"                'faces': data_utils.unflatten_faces(\n",
"                    f_samples_np['faces'][k][:f_samples_np['num_face_indices'][k]])\n",
"            }\n",
"        )\n",
"      data_utils.plot_meshes(mesh_list, ax_lims=0.5)\n",
"    sess.run((vertex_model_optim_op, face_model_optim_op))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "training.ipynb",
"provenance": [
{
"file_id": "1QL8ib2FKPGUWFQbuX8AttUk-H34Al8Ue",
"timestamp": 1591364245034
},
{
"file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
"timestamp": 1591355096822
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}