Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-23 07:45:18 +08:00)
Adds PolyGen to public Deepmind Research Github repository
PiperOrigin-RevId: 327802622
Committed by: Saran Tunyasuvunakool
Parent: 56031adaa4
Commit: 22c3daff19
@@ -50,6 +50,8 @@ https://deepmind.com/research/publications/
* [Graph Matching Networks for Learning the Similarity of Graph Structured
  Objects](graph_matching_networks), ICML 2019
* [REGAL: Transfer Learning for Fast Optimization of Computation Graphs](regal)
* [PolyGen: An Autoregressive Generative Model of 3D Meshes](polygen), ICML 2020

## Disclaimer

@@ -0,0 +1,65 @@
# PolyGen: An Autoregressive Generative Model of 3D Meshes

This package provides an implementation of PolyGen as described in:

> **PolyGen: An Autoregressive Generative Model of 3D Meshes**, *Charlie Nash, Yaroslav Ganin, S. M. Ali Eslami, Peter W. Battaglia*, ICML, 2020. ([abs](https://arxiv.org/abs/2002.10880))

PolyGen is a generative model of 3D meshes that sequentially outputs mesh
vertices and faces. PolyGen consists of two parts: a vertex model, which
unconditionally models mesh vertices, and a face model, which models the mesh
faces conditioned on input vertices. The vertex model uses a masked Transformer
decoder to express a distribution over the vertex sequences. For the face model
we combine Transformers with [pointer networks](https://arxiv.org/abs/1506.03134)
to express a distribution over variable-length vertex sequences.

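As a rough illustration of the pointer-network idea used by the face model, the sketch below scores each candidate vertex by a dot product between a decoder output and that vertex's embedding, then normalizes over the valid vertices only. This is a NumPy toy under stated assumptions, not the code in `modules.py`: the helper name `pointer_distribution`, the embedding size and the random inputs are all hypothetical.

```python
import numpy as np


def pointer_distribution(decoder_output, vertex_embeddings, vertex_mask):
  """Returns one probability per vertex slot (hypothetical helper).

  decoder_output: [embed_dim] vector for the current decoding step.
  vertex_embeddings: [num_slots, embed_dim], one row per (padded) vertex.
  vertex_mask: [num_slots], 1.0 for real vertices and 0.0 for padding.
  """
  # Pointer scores: similarity between the decoder state and each vertex.
  scores = vertex_embeddings @ decoder_output
  # Padded slots get a very negative score so they receive zero probability.
  scores = np.where(vertex_mask > 0, scores, -1e9)
  scores = scores - scores.max()  # Subtract the max for numerical stability.
  probs = np.exp(scores)
  return probs / probs.sum()


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(5, 8))  # Three real vertices plus two padded slots.
mask = np.array([1., 1., 1., 0., 0.])
probs = pointer_distribution(rng.normal(size=8), embeddings, mask)
print(probs)
```

Because the output size follows the number of vertex slots rather than a fixed vocabulary, the same scoring step works for meshes with any number of vertices.
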
In this repository we provide model code in `modules.py`, as well as data
processing utilities in `data_utils.py`. We also provide Colabs that demonstrate
training PolyGen from scratch on a toy dataset, as well as sampling from a
pre-trained model.

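To give a feel for the data processing utilities, here is a minimal standalone sketch of 8-bit vertex quantization and its inverse, modeled on `quantize_verts` and `dequantize_verts` in `data_utils.py`. The function names `quantize` and `dequantize` and the sample vertices are illustrative only; the sketch assumes vertices already lie in [-0.5, 0.5].

```python
import numpy as np


def quantize(verts, n_bits=8):
  """Maps floats in [-0.5, 0.5] to integers in [0, 2**n_bits - 1]."""
  levels = 2**n_bits - 1
  # Shift to [0, 1], scale to [0, levels], then truncate to an integer bin.
  return ((verts + 0.5) * levels).astype('int32')


def dequantize(quantized, n_bits=8):
  """Maps integers in [0, 2**n_bits - 1] back to floats in [-0.5, 0.5]."""
  levels = 2**n_bits - 1
  return quantized.astype('float32') / levels - 0.5


verts = np.array([[-0.5, 0.0, 0.5], [0.25, -0.25, 0.1]], dtype=np.float32)
quantized = quantize(verts)
recovered = dequantize(quantized)
max_error = np.abs(recovered - verts).max()
print(quantized)
print(max_error)
```

Since the conversion truncates rather than rounds, the round-trip error is bounded by one quantization step, that is 1/255 for 8 bits.
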
There are some minor differences between this implementation and the paper:
* We add global information (e.g. class label embeddings) as an additional input in the first sequence position rather than projecting it at each layer. This reduces parameters, but does not significantly impact performance.
* We use [ReZero](https://arxiv.org/abs/2003.04887), which improves training speed.
* We train with only shifting augmentations, which we find to be as effective as the combination of augmentations described in the paper. This helps to simplify the data pre-processing pipeline.

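For readers unfamiliar with ReZero, the following sketch shows the residual form it uses: each sublayer output is scaled by a learned scalar that starts at zero, so every layer begins as the identity. The names `rezero_residual` and `toy_sublayer` are hypothetical stand-ins and do not reflect how `modules.py` is organized.

```python
import numpy as np


def rezero_residual(x, sublayer, alpha):
  """ReZero-style residual connection: x + alpha * sublayer(x)."""
  return x + alpha * sublayer(x)


def toy_sublayer(x):
  """Stand-in for an attention or feed-forward block (hypothetical)."""
  return np.tanh(x) * 3.0


x = np.array([0.5, -1.0, 2.0])
# alpha is a learned scalar initialized to zero, so at initialization the
# layer passes its input through unchanged.
out_at_init = rezero_residual(x, toy_sublayer, alpha=0.0)
# Once training moves alpha away from zero, the sublayer starts to contribute.
out_later = rezero_residual(x, toy_sublayer, alpha=0.1)
print(out_at_init, out_later)
```
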
## Training Colab [Open in Colab](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/training.ipynb)

To train a PolyGen model from scratch on a collection of simple meshes, use this
Colab. It demonstrates the data pre-processing required to create inputs for
the vertex and face models.

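The pre-processing mentioned above can be sketched in plain NumPy. The example mirrors the token layout used by `make_vertex_model_dataset` and `flatten_faces` in `data_utils.py` (coordinates shifted by one with 0 as the stop token; face indices shifted by two with 1 as "new face" and 0 as "stop"). The toy mesh is made up for illustration.

```python
import numpy as np

# Toy quantized mesh: three vertices stored as (x, y, z) and one triangle.
vertices = np.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
faces = [[0, 1, 2]]

# Vertex model input: reorder each vertex to (z, y, x), flatten, shift every
# coordinate by one, and append 0 as the stopping token.
vertices_zyx = vertices[:, [2, 1, 0]]
vertices_flat = np.concatenate([vertices_zyx.reshape(-1) + 1, [0]])

# Face model target: vertex indices are shifted by two so that 1 can mark
# "new face" and 0 can mark "stop".
flat_faces = []
for i, face in enumerate(faces):
  flat_faces.extend(index + 2 for index in face)
  flat_faces.append(0 if i == len(faces) - 1 else 1)
flat_faces = np.array(flat_faces)

print(vertices_flat.tolist())  # [31, 21, 11, 61, 51, 41, 91, 81, 71, 0]
print(flat_faces.tolist())     # [2, 3, 4, 0]
```
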
## Sampling pre-trained model Colab [Open in Colab](https://colab.research.google.com/github/deepmind/deepmind-research/blob/master/polygen/sample-pretrained.ipynb)

To sample from a model pre-trained on [ShapeNet](https://www.shapenet.org/),
use this Colab. The model is class-conditional, and is trained on longer
sequence lengths than those described in the paper. This Colab uses the
checkpoints stored in this [Google Cloud Storage
bucket](https://console.cloud.google.com/storage/browser/deepmind-research-polygen).

## Installation

To install the package locally, run:
```bash
git clone https://github.com/deepmind/deepmind-research.git
cd deepmind-research/polygen
pip install -e .
```

## Giving Credit

If you use this code in your work, we ask you to cite this paper:

```
@article{nash2020polygen,
  author={Charlie Nash and Yaroslav Ganin and S. M. Ali Eslami and Peter W. Battaglia},
  title={PolyGen: An Autoregressive Generative Model of 3D Meshes},
  journal={ICML},
  year={2020}
}
```

## Disclaimer

This is not an official Google product.
@@ -0,0 +1,409 @@
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mesh data utilities."""
import matplotlib.pyplot as plt
import modules
from mpl_toolkits import mplot3d  # pylint: disable=unused-import
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import networkx as nx
import numpy as np
import six
from six.moves import range
import tensorflow.compat.v1 as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


def random_shift(vertices, shift_factor=0.25):
  """Apply random shift to vertices."""
  max_shift_pos = tf.cast(255 - tf.reduce_max(vertices, axis=0), tf.float32)
  max_shift_pos = tf.maximum(max_shift_pos, 1e-9)

  max_shift_neg = tf.cast(tf.reduce_min(vertices, axis=0), tf.float32)
  max_shift_neg = tf.maximum(max_shift_neg, 1e-9)

  shift = tfd.TruncatedNormal(
      tf.zeros([1, 3]), shift_factor*255, -max_shift_neg,
      max_shift_pos).sample()
  shift = tf.cast(shift, tf.int32)
  vertices += shift
  return vertices


def make_vertex_model_dataset(ds, apply_random_shift=False):
  """Prepare dataset for vertex model training."""
  def _vertex_model_map_fn(example):
    vertices = example['vertices']

    # Randomly shift vertices
    if apply_random_shift:
      vertices = random_shift(vertices)

    # Re-order vertex coordinates as (z, y, x).
    vertices_permuted = tf.stack(
        [vertices[:, 2], vertices[:, 1], vertices[:, 0]], axis=-1)

    # Flatten quantized vertices, reindex starting from 1, and pad with a
    # zero stopping token.
    vertices_flat = tf.reshape(vertices_permuted, [-1])
    example['vertices_flat'] = tf.pad(vertices_flat + 1, [[0, 1]])

    # Create mask to indicate valid tokens after padding and batching.
    example['vertices_flat_mask'] = tf.ones_like(
        example['vertices_flat'], dtype=tf.float32)
    return example
  return ds.map(_vertex_model_map_fn)


def make_face_model_dataset(
    ds, apply_random_shift=False, shuffle_vertices=True, quantization_bits=8):
  """Prepare dataset for face model training."""
  def _face_model_map_fn(example):
    vertices = example['vertices']

    # Randomly shift vertices
    if apply_random_shift:
      vertices = random_shift(vertices)
    example['num_vertices'] = tf.shape(vertices)[0]

    # Optionally shuffle vertices and re-order faces to match
    if shuffle_vertices:
      permutation = tf.random_shuffle(tf.range(example['num_vertices']))
      vertices = tf.gather(vertices, permutation)
      face_permutation = tf.concat(
          [tf.constant([0, 1], dtype=tf.int32), tf.argsort(permutation) + 2],
          axis=0)
      example['faces'] = tf.cast(
          tf.gather(face_permutation, example['faces']), tf.int64)

    # Vertices are quantized. So convert to floats for input to face model
    example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
    example['vertices_mask'] = tf.ones_like(
        example['vertices'][Ellipsis, 0], dtype=tf.float32)
    example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
    return example
  return ds.map(_face_model_map_fn)


def read_obj(obj_path):
  """Read vertices and faces from .obj file."""
  vertex_list = []
  flat_vertices_list = []
  flat_vertices_indices = {}
  flat_triangles = []

  with open(obj_path) as obj_file:
    for line in obj_file:
      tokens = line.split()
      if not tokens:
        continue
      line_type = tokens[0]
      # We skip lines not starting with v or f.
      if line_type == 'v':
        vertex_list.append([float(x) for x in tokens[1:]])
      elif line_type == 'f':
        triangle = []
        for i in range(len(tokens) - 1):
          vertex_name = tokens[i + 1]
          if vertex_name in flat_vertices_indices:
            triangle.append(flat_vertices_indices[vertex_name])
            continue
          flat_vertex = []
          for index in six.ensure_str(vertex_name).split('/'):
            if not index:
              continue
            # obj triangle indices are 1 indexed, so subtract 1 here.
            flat_vertex += vertex_list[int(index) - 1]
          flat_vertex_index = len(flat_vertices_list)
          flat_vertices_list.append(flat_vertex)
          flat_vertices_indices[vertex_name] = flat_vertex_index
          triangle.append(flat_vertex_index)
        flat_triangles.append(triangle)

  return np.array(flat_vertices_list, dtype=np.float32), flat_triangles


def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
  """Write vertices and faces to obj."""
  if transpose:
    vertices = vertices[:, [1, 2, 0]]
  # Scale into a new array so the caller's vertices are not modified in place.
  vertices = vertices * scale
  has_faces = faces is not None and len(faces) > 0
  # obj face indices are 1-indexed, so shift 0-indexed faces by one. The
  # minimum is taken over every index in every face.
  f_add = 1 if has_faces and min(min(f) for f in faces) == 0 else 0
  with open(file_path, 'w') as f:
    for v in vertices:
      f.write('v {} {} {}\n'.format(v[0], v[1], v[2]))
    if has_faces:
      for face in faces:
        line = 'f'
        for i in face:
          line += ' {}'.format(i + f_add)
        line += '\n'
        f.write(line)


def quantize_verts(verts, n_bits=8):
  """Convert vertices in [-0.5, 0.5] to discrete values in [0, 2**n_bits - 1]."""
  min_range = -0.5
  max_range = 0.5
  range_quantize = 2**n_bits - 1
  verts_quantize = (verts - min_range) * range_quantize / (
      max_range - min_range)
  return verts_quantize.astype('int32')


def dequantize_verts(verts, n_bits=8, add_noise=False):
  """Convert quantized vertices to floats."""
  min_range = -0.5
  max_range = 0.5
  range_quantize = 2**n_bits - 1
  verts = verts.astype('float32')
  verts = verts * (max_range - min_range) / range_quantize + min_range
  if add_noise:
    verts += np.random.uniform(size=verts.shape) * (1 / range_quantize)
  return verts


def face_to_cycles(face):
  """Find cycles in face."""
  g = nx.Graph()
  for v in range(len(face) - 1):
    g.add_edge(face[v], face[v + 1])
  g.add_edge(face[-1], face[0])
  return list(nx.cycle_basis(g))


def flatten_faces(faces):
  """Converts from list of faces to flat face array with stopping indices."""
  if not faces:
    return np.array([0])
  else:
    l = [f + [-1] for f in faces[:-1]]
    l += [faces[-1] + [-2]]
    return np.array([item for sublist in l for item in sublist]) + 2  # pylint: disable=g-complex-comprehension


def unflatten_faces(flat_faces):
  """Converts from flat face sequence to a list of separate faces."""
  def group(seq):
    g = []
    for el in seq:
      if el == 0 or el == -1:
        yield g
        g = []
      else:
        g.append(el - 1)
    yield g
  outputs = list(group(flat_faces - 1))[:-1]
  # Remove empty and degenerate faces (fewer than three vertices).
  return [o for o in outputs if len(o) > 2]


def center_vertices(vertices):
  """Translate the vertices so that bounding box is centered at zero."""
  vert_min = vertices.min(axis=0)
  vert_max = vertices.max(axis=0)
  vert_center = 0.5 * (vert_min + vert_max)
  return vertices - vert_center


def normalize_vertices_scale(vertices):
  """Scale the vertices so that the long diagonal of the bounding box is one."""
  vert_min = vertices.min(axis=0)
  vert_max = vertices.max(axis=0)
  extents = vert_max - vert_min
  scale = np.sqrt(np.sum(extents**2))
  return vertices / scale


def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8):
  """Quantize vertices, remove resulting duplicates and reindex faces."""
  vertices = quantize_verts(vertices, quantization_bits)
  vertices, inv = np.unique(vertices, axis=0, return_inverse=True)

  # Sort vertices by z then y then x.
  sort_inds = np.lexsort(vertices.T)
  vertices = vertices[sort_inds]

  # Re-index faces and tris to re-ordered vertices.
  faces = [np.argsort(sort_inds)[inv[f]] for f in faces]
  if tris is not None:
    tris = np.array([np.argsort(sort_inds)[inv[t]] for t in tris])

  # Merging duplicate vertices and re-indexing the faces causes some faces to
  # contain loops (e.g. [2, 3, 5, 2, 4]). Split these faces into distinct
  # sub-faces.
  sub_faces = []
  for f in faces:
    cliques = face_to_cycles(f)
    for c in cliques:
      c_length = len(c)
      # Only append faces with more than two verts.
      if c_length > 2:
        d = np.argmin(c)
        # Cyclically permute faces so that the first index is the smallest.
        sub_faces.append([c[(d + i) % c_length] for i in range(c_length)])
  faces = sub_faces
  if tris is not None:
    tris = np.array([v for v in tris if len(set(v)) == len(v)])

  # Sort faces by lowest vertex indices. If two faces have the same lowest
  # index then sort by next lowest and so on.
  faces.sort(key=lambda f: tuple(sorted(f)))
  if tris is not None:
    tris = tris.tolist()
    tris.sort(key=lambda f: tuple(sorted(f)))
    tris = np.array(tris)

  # After removing degenerate faces some vertices are now unreferenced.
  # Remove these.
  num_verts = vertices.shape[0]
  vert_connected = np.equal(
      np.arange(num_verts)[:, None], np.hstack(faces)[None]).any(axis=-1)
  vertices = vertices[vert_connected]

  # Re-index faces and tris to re-ordered vertices.
  vert_indices = (
      np.arange(num_verts) - np.cumsum(1 - vert_connected.astype('int')))
  faces = [vert_indices[f].tolist() for f in faces]
  if tris is not None:
    tris = np.array([vert_indices[t].tolist() for t in tris])

  return vertices, faces, tris


def load_process_mesh(mesh_obj_path, quantization_bits=8):
  """Load obj file and process."""
  # Load mesh
  vertices, faces = read_obj(mesh_obj_path)

  # Transpose so that z-axis is vertical.
  vertices = vertices[:, [2, 0, 1]]

  # Translate the vertices so that bounding box is centered at zero.
  vertices = center_vertices(vertices)

  # Scale the vertices so that the long diagonal of the bounding box is equal
  # to one.
  vertices = normalize_vertices_scale(vertices)

  # Quantize and sort vertices, remove resulting duplicates, sort and reindex
  # faces.
  vertices, faces, _ = quantize_process_mesh(
      vertices, faces, quantization_bits=quantization_bits)

  # Flatten faces and add 'new face' = 1 and 'stop' = 0 tokens.
  faces = flatten_faces(faces)

  # Return the processed mesh. A mesh left without faces yields the single
  # stop token from flatten_faces; nothing is filtered out here.
  return {
      'vertices': vertices,
      'faces': faces,
  }


def plot_meshes(mesh_list,
                ax_lims=0.3,
                fig_size=4,
                el=30,
                rot_start=120,
                vert_size=10,
                vert_alpha=0.75,
                n_cols=4):
  """Plots mesh data using matplotlib."""

  n_plot = len(mesh_list)
  n_cols = np.minimum(n_plot, n_cols)
  n_rows = np.ceil(n_plot / n_cols).astype('int')
  fig = plt.figure(figsize=(fig_size * n_cols, fig_size * n_rows))
  for p_inc, mesh in enumerate(mesh_list):

    for key in [
        'vertices', 'faces', 'vertices_conditional', 'pointcloud', 'class_name'
    ]:
      if key not in list(mesh.keys()):
        mesh[key] = None

    ax = fig.add_subplot(n_rows, n_cols, p_inc + 1, projection='3d')

    if mesh['faces'] is not None:
      if mesh['vertices_conditional'] is not None:
        face_verts = np.concatenate(
            [mesh['vertices_conditional'], mesh['vertices']], axis=0)
      else:
        face_verts = mesh['vertices']
      collection = []
      for f in mesh['faces']:
        collection.append(face_verts[f])
      plt_mesh = Poly3DCollection(collection)
      plt_mesh.set_edgecolor((0., 0., 0., 0.3))
      plt_mesh.set_facecolor((1, 0, 0, 0.2))
      ax.add_collection3d(plt_mesh)

    if mesh['vertices'] is not None:
      ax.scatter3D(
          mesh['vertices'][:, 0],
          mesh['vertices'][:, 1],
          mesh['vertices'][:, 2],
          lw=0.,
          s=vert_size,
          c='g',
          alpha=vert_alpha)

    if mesh['vertices_conditional'] is not None:
      ax.scatter3D(
          mesh['vertices_conditional'][:, 0],
          mesh['vertices_conditional'][:, 1],
          mesh['vertices_conditional'][:, 2],
          lw=0.,
          s=vert_size,
          c='b',
          alpha=vert_alpha)

    if mesh['pointcloud'] is not None:
      ax.scatter3D(
          mesh['pointcloud'][:, 0],
          mesh['pointcloud'][:, 1],
          mesh['pointcloud'][:, 2],
          lw=0.,
          s=2.5 * vert_size,
          c='b',
          alpha=1.)

    ax.set_xlim(-ax_lims, ax_lims)
    ax.set_ylim(-ax_lims, ax_lims)
    ax.set_zlim(-ax_lims, ax_lims)

    ax.view_init(el, rot_start)

    display_string = ''
    if mesh['faces'] is not None:
      display_string += 'Num. faces: {}\n'.format(len(collection))
    if mesh['vertices'] is not None:
      num_verts = mesh['vertices'].shape[0]
      if mesh['vertices_conditional'] is not None:
        num_verts += mesh['vertices_conditional'].shape[0]
      display_string += 'Num. verts: {}\n'.format(num_verts)
    if mesh['class_name'] is not None:
      # End with a newline so a following pointcloud line starts on its own row.
      display_string += 'Synset: {}\n'.format(mesh['class_name'])
    if mesh['pointcloud'] is not None:
      display_string += 'Num. pointcloud: {}\n'.format(
          mesh['pointcloud'].shape[0])
    ax.text2D(0.05, 0.8, display_string, transform=ax.transAxes)
  plt.subplots_adjust(
      left=0., right=1., bottom=0., top=1., wspace=0.025, hspace=0.025)
  plt.show()
Binary file not shown. (Size: 281 KiB)
@@ -0,0 +1,66 @@
v 0.000000 -1.000000 -1.000000
v 0.195090 -1.000000 -0.980785
v 0.382683 -1.000000 -0.923880
v 0.555570 -1.000000 -0.831470
v 0.707107 -1.000000 -0.707107
v 0.831470 -1.000000 -0.555570
v 0.923880 -1.000000 -0.382683
v 0.980785 -1.000000 -0.195090
v 1.000000 -1.000000 -0.000000
v 0.980785 -1.000000 0.195090
v 0.923880 -1.000000 0.382683
v 0.831470 -1.000000 0.555570
v 0.707107 -1.000000 0.707107
v 0.555570 -1.000000 0.831470
v 0.382683 -1.000000 0.923880
v 0.195090 -1.000000 0.980785
v -0.000000 -1.000000 1.000000
v -0.195091 -1.000000 0.980785
v -0.382684 -1.000000 0.923879
v -0.555571 -1.000000 0.831469
v -0.707107 -1.000000 0.707106
v -0.831470 -1.000000 0.555570
v -0.923880 -1.000000 0.382683
v 0.000000 1.000000 0.000000
v -0.980785 -1.000000 0.195089
v -1.000000 -1.000000 -0.000001
v -0.980785 -1.000000 -0.195091
v -0.923879 -1.000000 -0.382684
v -0.831469 -1.000000 -0.555571
v -0.707106 -1.000000 -0.707108
v -0.555569 -1.000000 -0.831470
v -0.382682 -1.000000 -0.923880
v -0.195089 -1.000000 -0.980786
f 1 24 2
f 2 24 3
f 3 24 4
f 4 24 5
f 5 24 6
f 6 24 7
f 7 24 8
f 8 24 9
f 9 24 10
f 10 24 11
f 11 24 12
f 12 24 13
f 13 24 14
f 14 24 15
f 15 24 16
f 16 24 17
f 17 24 18
f 18 24 19
f 19 24 20
f 20 24 21
f 21 24 22
f 22 24 23
f 23 24 25
f 25 24 26
f 26 24 27
f 27 24 28
f 28 24 29
f 29 24 30
f 30 24 31
f 31 24 32
f 32 24 33
f 33 24 1
f 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 25 26 27 28 29 30 31 32 33
@@ -0,0 +1,14 @@
v -1.000000 -1.000000 1.000000
v -1.000000 1.000000 1.000000
v -1.000000 -1.000000 -1.000000
v -1.000000 1.000000 -1.000000
v 1.000000 -1.000000 1.000000
v 1.000000 1.000000 1.000000
v 1.000000 -1.000000 -1.000000
v 1.000000 1.000000 -1.000000
f 1 2 4 3
f 3 4 8 7
f 7 8 6 5
f 5 6 2 1
f 3 7 5 1
f 8 4 2 6
@@ -0,0 +1,98 @@
v 0.000000 -1.000000 -1.000000
v 0.000000 1.000000 -1.000000
v 0.195090 -1.000000 -0.980785
v 0.195090 1.000000 -0.980785
v 0.382683 -1.000000 -0.923880
v 0.382683 1.000000 -0.923880
v 0.555570 -1.000000 -0.831470
v 0.555570 1.000000 -0.831470
v 0.707107 -1.000000 -0.707107
v 0.707107 1.000000 -0.707107
v 0.831470 -1.000000 -0.555570
v 0.831470 1.000000 -0.555570
v 0.923880 -1.000000 -0.382683
v 0.923880 1.000000 -0.382683
v 0.980785 -1.000000 -0.195090
v 0.980785 1.000000 -0.195090
v 1.000000 -1.000000 -0.000000
v 1.000000 1.000000 -0.000000
v 0.980785 -1.000000 0.195090
v 0.980785 1.000000 0.195090
v 0.923880 -1.000000 0.382683
v 0.923880 1.000000 0.382683
v 0.831470 -1.000000 0.555570
v 0.831470 1.000000 0.555570
v 0.707107 -1.000000 0.707107
v 0.707107 1.000000 0.707107
v 0.555570 -1.000000 0.831470
v 0.555570 1.000000 0.831470
v 0.382683 -1.000000 0.923880
v 0.382683 1.000000 0.923880
v 0.195090 -1.000000 0.980785
v 0.195090 1.000000 0.980785
v -0.000000 -1.000000 1.000000
v -0.000000 1.000000 1.000000
v -0.195091 -1.000000 0.980785
v -0.195091 1.000000 0.980785
v -0.382684 -1.000000 0.923879
v -0.382684 1.000000 0.923879
v -0.555571 -1.000000 0.831469
v -0.555571 1.000000 0.831469
v -0.707107 -1.000000 0.707106
v -0.707107 1.000000 0.707106
v -0.831470 -1.000000 0.555570
v -0.831470 1.000000 0.555570
v -0.923880 -1.000000 0.382683
v -0.923880 1.000000 0.382683
v -0.980785 -1.000000 0.195089
v -0.980785 1.000000 0.195089
v -1.000000 -1.000000 -0.000001
v -1.000000 1.000000 -0.000001
v -0.980785 -1.000000 -0.195091
v -0.980785 1.000000 -0.195091
v -0.923879 -1.000000 -0.382684
v -0.923879 1.000000 -0.382684
v -0.831469 -1.000000 -0.555571
v -0.831469 1.000000 -0.555571
v -0.707106 -1.000000 -0.707108
v -0.707106 1.000000 -0.707108
v -0.555569 -1.000000 -0.831470
v -0.555569 1.000000 -0.831470
v -0.382682 -1.000000 -0.923880
v -0.382682 1.000000 -0.923880
v -0.195089 -1.000000 -0.980786
v -0.195089 1.000000 -0.980786
f 1 2 4 3
f 3 4 6 5
f 5 6 8 7
f 7 8 10 9
f 9 10 12 11
f 11 12 14 13
f 13 14 16 15
f 15 16 18 17
f 17 18 20 19
f 19 20 22 21
f 21 22 24 23
f 23 24 26 25
f 25 26 28 27
f 27 28 30 29
f 29 30 32 31
f 31 32 34 33
f 33 34 36 35
f 35 36 38 37
f 37 38 40 39
f 39 40 42 41
f 41 42 44 43
f 43 44 46 45
f 45 46 48 47
f 47 48 50 49
f 49 50 52 51
f 51 52 54 53
f 53 54 56 55
f 55 56 58 57
f 57 58 60 59
f 59 60 62 61
f 4 2 64 62 60 58 56 54 52 50 48 46 44 42 40 38 36 34 32 30 28 26 24 22 20 18 16 14 12 10 8 6
f 61 62 64 63
f 63 64 2 1
f 1 3 5 7 9 11 13 15 17 19 21 23 25 27 29 31 33 35 37 39 41 43 45 47 49 51 53 55 57 59 61 63
@@ -0,0 +1,122 @@
v 0.000000 -1.000000 0.000000
v 0.723607 -0.447220 0.525725
v -0.276388 -0.447220 0.850649
v -0.894426 -0.447216 0.000000
v -0.276388 -0.447220 -0.850649
v 0.723607 -0.447220 -0.525725
v 0.276388 0.447220 0.850649
v -0.723607 0.447220 0.525725
v -0.723607 0.447220 -0.525725
v 0.276388 0.447220 -0.850649
v 0.894426 0.447216 0.000000
v 0.000000 1.000000 0.000000
v -0.162456 -0.850654 0.499995
v 0.425323 -0.850654 0.309011
v 0.262869 -0.525738 0.809012
v 0.850648 -0.525736 0.000000
v 0.425323 -0.850654 -0.309011
v -0.525730 -0.850652 0.000000
v -0.688189 -0.525736 0.499997
v -0.162456 -0.850654 -0.499995
v -0.688189 -0.525736 -0.499997
v 0.262869 -0.525738 -0.809012
v 0.951058 0.000000 0.309013
v 0.951058 0.000000 -0.309013
v 0.000000 0.000000 1.000000
v 0.587786 0.000000 0.809017
v -0.951058 0.000000 0.309013
v -0.587786 0.000000 0.809017
v -0.587786 0.000000 -0.809017
v -0.951058 0.000000 -0.309013
v 0.587786 0.000000 -0.809017
v 0.000000 0.000000 -1.000000
v 0.688189 0.525736 0.499997
v -0.262869 0.525738 0.809012
v -0.850648 0.525736 0.000000
v -0.262869 0.525738 -0.809012
v 0.688189 0.525736 -0.499997
v 0.162456 0.850654 0.499995
v 0.525730 0.850652 0.000000
v -0.425323 0.850654 0.309011
v -0.425323 0.850654 -0.309011
v 0.162456 0.850654 -0.499995
f 1 14 13
f 2 14 16
f 1 13 18
f 1 18 20
f 1 20 17
f 2 16 23
f 3 15 25
f 4 19 27
f 5 21 29
f 6 22 31
f 2 23 26
f 3 25 28
f 4 27 30
f 5 29 32
f 6 31 24
f 7 33 38
f 8 34 40
f 9 35 41
f 10 36 42
f 11 37 39
f 39 42 12
f 39 37 42
f 37 10 42
f 42 41 12
f 42 36 41
f 36 9 41
f 41 40 12
f 41 35 40
f 35 8 40
f 40 38 12
f 40 34 38
f 34 7 38
f 38 39 12
f 38 33 39
f 33 11 39
f 24 37 11
f 24 31 37
f 31 10 37
f 32 36 10
f 32 29 36
f 29 9 36
f 30 35 9
f 30 27 35
f 27 8 35
f 28 34 8
f 28 25 34
f 25 7 34
f 26 33 7
f 26 23 33
f 23 11 33
f 31 32 10
f 31 22 32
f 22 5 32
f 29 30 9
f 29 21 30
f 21 4 30
f 27 28 8
f 27 19 28
f 19 3 28
f 25 26 7
f 25 15 26
f 15 2 26
f 23 24 11
f 23 16 24
f 16 6 24
f 17 22 6
f 17 20 22
f 20 5 22
f 20 21 5
f 20 18 21
f 18 4 21
f 18 19 4
f 18 13 19
f 13 3 19
f 16 17 6
f 16 14 17
f 14 1 17
f 13 15 3
f 13 14 15
f 14 2 15
@@ -0,0 +1,159 @@
# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the PolyGen open-source version."""
from modules import FaceModel
from modules import VertexModel
import numpy as np
import tensorflow as tf

_BATCH_SIZE = 4
_TRANSFORMER_CONFIG = {
    'num_layers': 2,
    'hidden_size': 64,
    'fc_size': 256
}
_CLASS_CONDITIONAL = True
_NUM_CLASSES = 4
_NUM_INPUT_VERTS = 50
_NUM_PAD_VERTS = 10
_NUM_INPUT_FACE_INDICES = 200
_QUANTIZATION_BITS = 8
_VERTEX_MODEL_USE_DISCRETE_EMBEDDINGS = True
_FACE_MODEL_DECODER_CROSS_ATTENTION = True
_FACE_MODEL_DISCRETE_EMBEDDINGS = True
_MAX_SAMPLE_LENGTH_VERTS = 10
_MAX_SAMPLE_LENGTH_FACES = 10


def _get_vertex_model_batch():
  """Returns batch with placeholders for vertex model inputs."""
  return {
      'class_label': tf.range(_BATCH_SIZE),
      'vertices_flat': tf.placeholder(
          dtype=tf.int32, shape=[_BATCH_SIZE, None]),
  }


def _get_face_model_batch():
  """Returns batch with placeholders for face model inputs."""
  return {
      'vertices': tf.placeholder(
          dtype=tf.float32, shape=[_BATCH_SIZE, None, 3]),
      'vertices_mask': tf.placeholder(
          dtype=tf.float32, shape=[_BATCH_SIZE, None]),
      'faces': tf.placeholder(
          dtype=tf.int32, shape=[_BATCH_SIZE, None]),
  }


class VertexModelTest(tf.test.TestCase):

  def setUp(self):
    """Defines a vertex model."""
    super(VertexModelTest, self).setUp()
    self.model = VertexModel(
        decoder_config=_TRANSFORMER_CONFIG,
        class_conditional=_CLASS_CONDITIONAL,
        num_classes=_NUM_CLASSES,
        max_num_input_verts=_NUM_INPUT_VERTS,
        quantization_bits=_QUANTIZATION_BITS,
        use_discrete_embeddings=_VERTEX_MODEL_USE_DISCRETE_EMBEDDINGS)

  def test_model_runs(self):
    """Tests if the model runs without crashing."""
    batch = _get_vertex_model_batch()
    pred_dist = self.model(batch, is_training=False)
    logits = pred_dist.logits
    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      vertices_flat = np.random.randint(
          2**_QUANTIZATION_BITS + 1,
          size=[_BATCH_SIZE, _NUM_INPUT_VERTS * 3 + 1])
      sess.run(logits, {batch['vertices_flat']: vertices_flat})

  def test_sample_outputs_range(self):
    """Tests if the model produces samples in the correct range."""
    context = {'class_label': tf.zeros((_BATCH_SIZE,), dtype=tf.int32)}
    sample_dict = self.model.sample(
        _BATCH_SIZE, max_sample_length=_MAX_SAMPLE_LENGTH_VERTS,
        context=context)
    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      sample_dict_np = sess.run(sample_dict)
      in_range = np.logical_and(
          0 <= sample_dict_np['vertices'],
          sample_dict_np['vertices'] <= 2**_QUANTIZATION_BITS).all()
      self.assertTrue(in_range)


class FaceModelTest(tf.test.TestCase):

  def setUp(self):
    """Defines a face model."""
    super(FaceModelTest, self).setUp()
    self.model = FaceModel(
        encoder_config=_TRANSFORMER_CONFIG,
        decoder_config=_TRANSFORMER_CONFIG,
        class_conditional=False,
        max_seq_length=_NUM_INPUT_FACE_INDICES,
        decoder_cross_attention=_FACE_MODEL_DECODER_CROSS_ATTENTION,
        use_discrete_vertex_embeddings=_FACE_MODEL_DISCRETE_EMBEDDINGS,
        quantization_bits=_QUANTIZATION_BITS)

  def test_model_runs(self):
    """Tests if the model runs without crashing."""
    batch = _get_face_model_batch()
    pred_dist = self.model(batch, is_training=False)
    logits = pred_dist.logits
    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      vertices = np.random.rand(_BATCH_SIZE, _NUM_INPUT_VERTS, 3) - 0.5
      vertices_mask = np.ones([_BATCH_SIZE, _NUM_INPUT_VERTS])
      faces = np.random.randint(
          _NUM_INPUT_VERTS + 2, size=[_BATCH_SIZE, _NUM_INPUT_FACE_INDICES])
      sess.run(
          logits,
          {batch['vertices']: vertices,
           batch['vertices_mask']: vertices_mask,
           batch['faces']: faces}
      )

  def test_sample_outputs_range(self):
    """Tests if the model produces samples in the correct range."""
    context = _get_face_model_batch()
    del context['faces']
    sample_dict = self.model.sample(
        context, max_sample_length=_MAX_SAMPLE_LENGTH_FACES)
    with self.session() as sess:
      sess.run(tf.global_variables_initializer())
      # Pad the vertices in order to test that the face model only outputs
      # vertex indices in the unpadded range
      vertices = np.pad(
          np.random.rand(_BATCH_SIZE, _NUM_INPUT_VERTS, 3) - 0.5,
          [[0, 0], [0, _NUM_PAD_VERTS], [0, 0]], mode='constant')
      vertices_mask = np.pad(
          np.ones([_BATCH_SIZE, _NUM_INPUT_VERTS]),
          [[0, 0], [0, _NUM_PAD_VERTS]], mode='constant')
      sample_dict_np = sess.run(
          sample_dict,
          {context['vertices']: vertices,
           context['vertices_mask']: vertices_mask})
      in_range = np.logical_and(
          0 <= sample_dict_np['faces'],
          sample_dict_np['faces'] <= _NUM_INPUT_VERTS + 1).all()
      self.assertTrue(in_range)


if __name__ == '__main__':
|
||||
tf.test.main()
|
||||
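The padding step used in `FaceModelTest.test_sample_outputs_range` can be looked at in isolation. The following is a minimal NumPy sketch, assuming small illustrative sizes: `batch_size`, `num_verts` and `num_pad` are hypothetical stand-ins for the test's `_BATCH_SIZE`, `_NUM_INPUT_VERTS` and `_NUM_PAD_VERTS`. It only shows how the padded vertices and the mask line up, and does not run the model.

```python
import numpy as np

# Hypothetical sizes for illustration; the test file defines its own constants.
batch_size, num_verts, num_pad = 2, 5, 3

# Real vertices lie in [-0.5, 0.5); padded rows are filled with zeros.
vertices = np.pad(
    np.random.rand(batch_size, num_verts, 3) - 0.5,
    [[0, 0], [0, num_pad], [0, 0]], mode='constant')

# The mask is 1 for real vertices and 0 for padded ones.
vertices_mask = np.pad(
    np.ones([batch_size, num_verts]),
    [[0, 0], [0, num_pad]], mode='constant')

print(vertices.shape)    # (2, 8, 3)
print(vertices_mask[0])  # [1. 1. 1. 1. 1. 0. 0. 0.]
```

The test then asserts that every sampled face index is at most `_NUM_INPUT_VERTS + 1`, so that no index refers to a padded vertex.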
+1506
File diff suppressed because it is too large
@@ -0,0 +1,20 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2020 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
python3 -m venv polygen
|
||||
. polygen/bin/activate
|
||||
pip3 install .
|
||||
python3 model_test.py
|
||||
deactivate
|
||||
@@ -0,0 +1,294 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "lUh_eWpealmh"
|
||||
},
|
||||
"source": [
|
||||
"Copyright 2020 DeepMind Technologies Limited\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"you may not use this file except in compliance with the License.\n",
|
||||
"You may obtain a copy of the License at\n",
|
||||
"\n",
|
||||
" https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"\n",
|
||||
"Unless required by applicable law or agreed to in writing, software\n",
|
||||
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"See the License for the specific language governing permissions and\n",
|
||||
"limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "kYd9gIfGJYZ8"
|
||||
},
|
||||
"source": [
|
||||
"## Clone repo and import dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Ux33ZDQ_tqUV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install tensorflow==1.15\n",
|
||||
"!pip install dm-sonnet==1.36\n",
|
||||
"!pip install tensor2tensor==1.14\n",
|
||||
"\n",
|
||||
"import time\n",
|
||||
"import numpy as np\n",
|
||||
"import tensorflow.compat.v1 as tf\n",
|
||||
"tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n",
|
||||
"%cd deepmind_research/polygen\n",
|
||||
"import modules\n",
|
||||
"import data_utils"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "YUZNqHTVJbm3"
|
||||
},
|
||||
"source": [
|
||||
"## Download pre-trained model weights from Google Cloud Storage"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "LpZjBUmq10gX"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!mkdir /tmp/vertex_model\n",
|
||||
"!mkdir /tmp/face_model\n",
|
||||
"!gsutil cp gs://deepmind-research-polygen/vertex_model.tar.gz /tmp/vertex_model/\n",
|
||||
"!gsutil cp gs://deepmind-research-polygen/face_model.tar.gz /tmp/face_model/\n",
|
||||
"!tar xvfz /tmp/vertex_model/vertex_model.tar.gz -C /tmp/vertex_model/\n",
|
||||
"!tar xvfz /tmp/face_model/face_model.tar.gz -C /tmp/face_model/"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "vXhMOoxaJ3Xb"
|
||||
},
|
||||
"source": [
|
||||
"## Pre-trained model config"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "7VGSSS9vJSn-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vertex_module_config=dict(\n",
|
||||
" decoder_config=dict(\n",
|
||||
" hidden_size=512,\n",
|
||||
" fc_size=2048,\n",
|
||||
" num_heads=8,\n",
|
||||
" layer_norm=True,\n",
|
||||
" num_layers=24,\n",
|
||||
" dropout_rate=0.4,\n",
|
||||
" re_zero=True,\n",
|
||||
" memory_efficient=True\n",
|
||||
" ),\n",
|
||||
" quantization_bits=8,\n",
|
||||
" class_conditional=True,\n",
|
||||
" max_num_input_verts=5000,\n",
|
||||
" use_discrete_embeddings=True,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"face_module_config=dict(\n",
|
||||
" encoder_config=dict(\n",
|
||||
" hidden_size=512,\n",
|
||||
" fc_size=2048,\n",
|
||||
" num_heads=8,\n",
|
||||
" layer_norm=True,\n",
|
||||
" num_layers=10,\n",
|
||||
" dropout_rate=0.2,\n",
|
||||
" re_zero=True,\n",
|
||||
" memory_efficient=True,\n",
|
||||
" ),\n",
|
||||
" decoder_config=dict(\n",
|
||||
" hidden_size=512,\n",
|
||||
" fc_size=2048,\n",
|
||||
" num_heads=8,\n",
|
||||
" layer_norm=True,\n",
|
||||
" num_layers=14,\n",
|
||||
" dropout_rate=0.2,\n",
|
||||
" re_zero=True,\n",
|
||||
" memory_efficient=True,\n",
|
||||
" ),\n",
|
||||
" class_conditional=False,\n",
|
||||
" decoder_cross_attention=True,\n",
|
||||
" use_discrete_vertex_embeddings=True,\n",
|
||||
" max_seq_length=8000,\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "WNXf_XbuKW4S"
|
||||
},
|
||||
"source": [
|
||||
"## Generate class-conditional samples\n",
|
||||
"\n",
|
||||
"Try varying the `class_id` parameter to generate meshes from different object categories. Good classes to try are tables (49), lamps (30), and cabinets (12). \n",
|
||||
"\n",
|
||||
"We can also specify the maximum number of vertices / face indices we want to see in the generated meshes using `max_num_vertices` and `max_num_face_indices`. The code will keep generating batches of samples until there are at least `num_samples_min` complete samples with the required number of vertices / faces.\n",
|
||||
"\n",
|
||||
"`top_p_vertex_model` and `top_p_face_model` control how varied the outputs are, with `1.` being the most varied, and `0.` the least varied. `0.9` is a good value for both the vertex and face models.\n",
|
||||
"\n",
|
||||
"Sampling should take around 2-5 minutes with a colab GPU using the default settings depending on the object class."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "kqKMbPJJu3lk"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class_id = '49) table' #@param ['0) airplane,aeroplane,plane','1) ashcan,trash can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin','2) bag,traveling bag,travelling bag,grip,suitcase','3) basket,handbasket','4) bathtub,bathing tub,bath,tub','5) bed','6) bench','7) birdhouse','8) bookshelf','9) bottle','10) bowl','11) bus,autobus,coach,charabanc,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehi','12) cabinet','13) camera,photographic camera','14) can,tin,tin can','15) cap','16) car,auto,automobile,machine,motorcar','17) cellular telephone,cellular phone,cellphone,cell,mobile phone','18) chair','19) clock','20) computer keyboard,keypad','21) dishwasher,dish washer,dishwashing machine','22) display,video display','23) earphone,earpiece,headphone,phone','24) faucet,spigot','25) file,file cabinet,filing cabinet','26) guitar','27) helmet','28) jar','29) knife','30) lamp','31) laptop,laptop computer','32) loudspeaker,speaker,speaker unit,loudspeaker system,speaker system','33) mailbox,letter box','34) microphone,mike','35) microwave,microwave oven','36) motorcycle,bike','37) mug','38) piano,pianoforte,forte-piano','39) pillow','40) pistol,handgun,side arm,shooting iron','41) pot,flowerpot','42) printer,printing machine','43) remote control,remote','44) rifle','45) rocket,projectile','46) skateboard','47) sofa,couch,lounge','48) stove','49) table','50) telephone,phone,telephone set','51) tower','52) train,railroad train','53) vessel,watercraft','54) washer,automatic washer,washing machine']\n",
|
||||
"num_samples_min = 1 #@param\n",
|
||||
"num_samples_batch = 8 #@param\n",
|
||||
"max_num_vertices = 400 #@param\n",
|
||||
"max_num_face_indices = 2000 #@param\n",
|
||||
"top_p_vertex_model = 0.9 #@param\n",
|
||||
"top_p_face_model = 0.9 #@param\n",
|
||||
"\n",
|
||||
"tf.reset_default_graph()\n",
|
||||
"\n",
|
||||
"# Build models\n",
|
||||
"vertex_model = modules.VertexModel(**vertex_module_config)\n",
|
||||
"face_model = modules.FaceModel(**face_module_config)\n",
|
||||
"\n",
|
||||
"# Tile out class label to every element in batch\n",
|
||||
"class_id = int(class_id.split(')')[0])\n",
|
||||
"vertex_model_context = {'class_label': tf.fill([num_samples_batch,], class_id)}\n",
|
||||
"vertex_samples = vertex_model.sample(\n",
|
||||
" num_samples_batch, context=vertex_model_context, \n",
|
||||
" max_sample_length=max_num_vertices, top_p=top_p_vertex_model, \n",
|
||||
" recenter_verts=True, only_return_complete=True)\n",
|
||||
"vertex_model_saver = tf.train.Saver(var_list=vertex_model.variables)\n",
|
||||
"\n",
|
||||
"# The face model generates samples conditioned on a context, which here is\n",
|
||||
"# the vertex model samples\n",
|
||||
"face_samples = face_model.sample(\n",
|
||||
" vertex_samples, max_sample_length=max_num_face_indices, \n",
|
||||
" top_p=top_p_face_model, only_return_complete=True)\n",
|
||||
"face_model_saver = tf.train.Saver(var_list=face_model.variables)\n",
|
||||
"\n",
|
||||
"# Start sampling\n",
|
||||
"start = time.time()\n",
|
||||
"print('Generating samples...')\n",
|
||||
"with tf.Session() as sess:\n",
|
||||
" vertex_model_saver.restore(sess, '/tmp/vertex_model/model')\n",
|
||||
" face_model_saver.restore(sess, '/tmp/face_model/model')\n",
|
||||
" mesh_list = []\n",
|
||||
" num_samples_complete = 0\n",
|
||||
" while num_samples_complete \u003c num_samples_min:\n",
|
||||
" v_samples_np = sess.run(vertex_samples)\n",
|
||||
" if v_samples_np['completed'].size == 0:\n",
|
||||
" print('No vertex samples completed in this batch. Try increasing ' +\n",
|
||||
" 'max_num_vertices.')\n",
|
||||
" continue\n",
|
||||
" f_samples_np = sess.run(\n",
|
||||
" face_samples,\n",
|
||||
" {vertex_samples[k]: v_samples_np[k] for k in vertex_samples.keys()})\n",
|
||||
" v_samples_np = f_samples_np['context']\n",
|
||||
" num_samples_complete_batch = f_samples_np['completed'].sum()\n",
|
||||
" num_samples_complete += num_samples_complete_batch\n",
|
||||
" print('Num. samples complete: {}'.format(num_samples_complete))\n",
|
||||
" for k in range(num_samples_complete_batch):\n",
|
||||
" verts = v_samples_np['vertices'][k][:v_samples_np['num_vertices'][k]]\n",
|
||||
" faces = data_utils.unflatten_faces(\n",
|
||||
" f_samples_np['faces'][k][:f_samples_np['num_face_indices'][k]])\n",
|
||||
" mesh_list.append({'vertices': verts, 'faces': faces})\n",
|
||||
"end = time.time()\n",
|
||||
"print('sampling time: {}'.format(end - start))\n",
|
||||
"\n",
|
||||
"data_utils.plot_meshes(mesh_list, ax_lims=0.4) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "OOQV6pMvSymz"
|
||||
},
|
||||
"source": [
|
||||
"## Export meshes as `.obj` files\n",
|
||||
"Pick a `mesh_id` (starting at 0) corresponding to the samples generated above. Refresh the colab file browser to find an `.obj` file with the mesh data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "fO0Klbq2Sx0m"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mesh_id = 4 #@param\n",
|
||||
"data_utils.write_obj(\n",
|
||||
" mesh_list[mesh_id]['vertices'], mesh_list[mesh_id]['faces'], \n",
|
||||
" 'mesh-{}.obj'.format(mesh_id))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "sampling-pretrained.ipynb",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1yj-oHYqCnwYVGSM22coa68NqnSoNdyVr",
|
||||
"timestamp": 1591622616999
|
||||
},
|
||||
{
|
||||
"file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
|
||||
"timestamp": 1591264007511
|
||||
}
|
||||
]
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
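The `top_p_vertex_model` and `top_p_face_model` settings in the notebook above correspond to nucleus (top-p) sampling. The sketch below is a generic NumPy illustration of that idea and is not the code in `modules.py`; the function name `top_p_filter` and the example probabilities are assumptions made for illustration.

```python
import numpy as np


def top_p_filter(probs, top_p):
  """Keeps the most likely tokens whose cumulative mass reaches `top_p`."""
  probs = np.asarray(probs, dtype=np.float64)
  order = np.argsort(-probs)  # Token indices, most likely first
  sorted_probs = probs[order]
  cumulative = np.cumsum(sorted_probs)
  # Keep a token if the probability mass before it is still below top_p.
  keep_sorted = (cumulative - sorted_probs) < top_p
  keep_sorted[0] = True  # Always keep the most likely token
  filtered = np.zeros_like(probs)
  filtered[order[keep_sorted]] = probs[order[keep_sorted]]
  return filtered / filtered.sum()  # Renormalize over the kept tokens


# Hypothetical next-token distribution over four tokens.
example_probs = [0.5, 0.3, 0.15, 0.05]
print(top_p_filter(example_probs, top_p=0.9))  # Least likely token is dropped
print(top_p_filter(example_probs, top_p=0.0))  # Only the most likely token
```

Lower values of `top_p` keep fewer candidate tokens, which is why the samples become less varied as `top_p` decreases.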
@@ -0,0 +1,40 @@
|
||||
# Copyright 2020 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Setup for pip package."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from setuptools import find_packages
|
||||
from setuptools import setup
|
||||
|
||||
|
||||
REQUIRED_PACKAGES = ['numpy', 'dm-sonnet==1.36', 'tensorflow==1.14',
|
||||
'tensor2tensor==1.15', 'networkx', 'matplotlib', 'six']
|
||||
|
||||
setup(
|
||||
name='polygen',
|
||||
version='0.1',
|
||||
description='A library for PolyGen: An Autoregressive Generative Model of 3D Meshes.',
|
||||
url='https://github.com/deepmind/deepmind-research/polygen',
|
||||
author='DeepMind',
|
||||
author_email='no-reply@google.com',
|
||||
# Contained modules and scripts.
|
||||
packages=find_packages(),
|
||||
install_requires=REQUIRED_PACKAGES,
|
||||
platforms=['any'],
|
||||
license='Apache 2.0',
|
||||
)
|
||||
@@ -0,0 +1,328 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "szgmaK1HajOc"
|
||||
},
|
||||
"source": [
|
||||
"Copyright 2020 DeepMind Technologies Limited\n",
|
||||
"\n",
|
||||
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"you may not use this file except in compliance with the License.\n",
|
||||
"You may obtain a copy of the License at\n",
|
||||
"\n",
|
||||
" https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"\n",
|
||||
"Unless required by applicable law or agreed to in writing, software\n",
|
||||
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"See the License for the specific language governing permissions and\n",
|
||||
"limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "_dv0afOrKheU"
|
||||
},
|
||||
"source": [
|
||||
"## Clone repo and import dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "Ux33ZDQ_tqUV"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install tensorflow==1.15\n",
|
||||
"!pip install dm-sonnet==1.36\n",
|
||||
"!pip install tensor2tensor==1.14\n",
|
||||
"\n",
|
||||
"import os\n",
|
||||
"import numpy as np\n",
|
||||
"import tensorflow.compat.v1 as tf\n",
|
||||
"tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"\n",
|
||||
"!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n",
|
||||
"%cd deepmind_research/polygen\n",
|
||||
"import modules\n",
|
||||
"import data_utils"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "U3GDZhJ5wGOf"
|
||||
},
|
||||
"source": [
|
||||
"## Prepare a synthetic dataset\n",
|
||||
"We prepare a dataset of meshes using four simple geometric primitives.\n",
|
||||
"\n",
|
||||
"The important function here is `data_utils.load_process_mesh`, which loads the raw `.obj` file, normalizes and centers the meshes, and applies quantization to the vertex positions. The mesh faces are flattened and treated as a long sequence, with a new-face token (`=1`) separating the faces. For each of the four synthetic meshes, we associate a unique class label, so we can train class-conditional models.\n",
|
||||
"\n",
|
||||
"After processing the raw mesh data into numpy arrays, we create a `tf.data.Dataset` that we can use to feed data to our models. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "3QAqwyZjtOdC"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Prepare synthetic dataset\n",
|
||||
"ex_list = []\n",
|
||||
"for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):\n",
|
||||
" mesh_dict = data_utils.load_process_mesh(\n",
|
||||
" os.path.join('meshes', '{}.obj'.format(mesh)))\n",
|
||||
" mesh_dict['class_label'] = k\n",
|
||||
" ex_list.append(mesh_dict)\n",
|
||||
"synthetic_dataset = tf.data.Dataset.from_generator(\n",
|
||||
" lambda: ex_list, \n",
|
||||
" output_types={\n",
|
||||
" 'vertices': tf.int32, 'faces': tf.int32, 'class_label': tf.int32},\n",
|
||||
" output_shapes={\n",
|
||||
" 'vertices': tf.TensorShape([None, 3]), 'faces': tf.TensorShape([None]), \n",
|
||||
" 'class_label': tf.TensorShape(())}\n",
|
||||
" )\n",
|
||||
"ex = synthetic_dataset.make_one_shot_iterator().get_next()\n",
|
||||
"\n",
|
||||
"# Inspect the first mesh\n",
|
||||
"with tf.Session() as sess:\n",
|
||||
" ex_np = sess.run(ex)\n",
|
||||
"print(ex_np)\n",
|
||||
"\n",
|
||||
"# Plot the meshes\n",
|
||||
"mesh_list = []\n",
|
||||
"with tf.Session() as sess:\n",
|
||||
" for i in range(4):\n",
|
||||
" ex_np = sess.run(ex)\n",
|
||||
" mesh_list.append(\n",
|
||||
" {'vertices': data_utils.dequantize_verts(ex_np['vertices']), \n",
|
||||
" 'faces': data_utils.unflatten_faces(ex_np['faces'])})\n",
|
||||
"data_utils.plot_meshes(mesh_list, ax_lims=0.4)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "9G2FCQQyyTXw"
|
||||
},
|
||||
"source": [
|
||||
"## Vertex model\n",
|
||||
"\n",
|
||||
"#### Prepare the dataset for vertex model training\n",
|
||||
"We need to perform some additional processing to make the dataset ready for vertex model training. In particular, `data_utils.make_vertex_model_dataset` flattens the `[V, 3]` vertex arrays, ordering by `Z-\u003eY-\u003eX` coordinates. It also creates masks, which are used to mask padded elements in data batches. Random shifts can also be applied to make the modelling task more challenging; they are disabled here (`apply_random_shift=False`).\n",
|
||||
"\n",
|
||||
"#### Create a vertex model\n",
|
||||
"`modules.VertexModel` is a Sonnet module. Calling the module on a batch of data will produce outputs which are the sequential predictions for each vertex coordinate. The basis of the vertex model is a Transformer decoder, and we specify its parameters in `decoder_config`. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "o2KCoDeeFP8C"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Prepare the dataset for vertex model training\n",
|
||||
"vertex_model_dataset = data_utils.make_vertex_model_dataset(\n",
|
||||
" synthetic_dataset, apply_random_shift=False)\n",
|
||||
"vertex_model_dataset = vertex_model_dataset.repeat()\n",
|
||||
"vertex_model_dataset = vertex_model_dataset.padded_batch(\n",
|
||||
" 4, padded_shapes=vertex_model_dataset.output_shapes)\n",
|
||||
"vertex_model_dataset = vertex_model_dataset.prefetch(1)\n",
|
||||
"vertex_model_batch = vertex_model_dataset.make_one_shot_iterator().get_next()\n",
|
||||
"\n",
|
||||
"# Create vertex model\n",
|
||||
"vertex_model = modules.VertexModel(\n",
|
||||
" decoder_config={\n",
|
||||
" 'hidden_size': 128,\n",
|
||||
" 'fc_size': 512, \n",
|
||||
" 'num_layers': 3,\n",
|
||||
" 'dropout_rate': 0.\n",
|
||||
" },\n",
|
||||
" class_conditional=True,\n",
|
||||
" num_classes=4,\n",
|
||||
" max_num_input_verts=250,\n",
|
||||
" quantization_bits=8,\n",
|
||||
")\n",
|
||||
"vertex_model_pred_dist = vertex_model(vertex_model_batch)\n",
|
||||
"vertex_model_loss = -tf.reduce_sum(\n",
|
||||
" vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat']) * \n",
|
||||
" vertex_model_batch['vertices_flat_mask'])\n",
|
||||
"vertex_samples = vertex_model.sample(\n",
|
||||
" 4, context=vertex_model_batch, max_sample_length=200, top_p=0.95,\n",
|
||||
" recenter_verts=False, only_return_complete=False)\n",
|
||||
"\n",
|
||||
"print(vertex_model_batch)\n",
|
||||
"print(vertex_model_pred_dist)\n",
|
||||
"print(vertex_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "-9RNYr5x1jov"
|
||||
},
|
||||
"source": [
|
||||
"## Face model\n",
|
||||
"\n",
|
||||
"#### Prepare the dataset for face model training\n",
|
||||
"We need to perform some additional processing to make the dataset ready for face model training. In particular, `data_utils.make_face_model_dataset` creates masks, which are used to mask padded elements in data batches. Random shifts can also be applied to make the modelling task more challenging; they are disabled here (`apply_random_shift=False`).\n",
|
||||
"\n",
|
||||
"#### Create a face model\n",
|
||||
"`modules.FaceModel` is a Sonnet module. Calling the module on a batch of data will produce outputs which are the sequential predictions for each face index. The face model is built from a Transformer encoder and a Transformer decoder, and we specify their parameters in `encoder_config` and `decoder_config`. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "a2yO6dOGzn8c"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"face_model_dataset = data_utils.make_face_model_dataset(\n",
|
||||
" synthetic_dataset, apply_random_shift=False)\n",
|
||||
"face_model_dataset = face_model_dataset.repeat()\n",
|
||||
"face_model_dataset = face_model_dataset.padded_batch(\n",
|
||||
" 4, padded_shapes=face_model_dataset.output_shapes)\n",
|
||||
"face_model_dataset = face_model_dataset.prefetch(1)\n",
|
||||
"face_model_batch = face_model_dataset.make_one_shot_iterator().get_next()\n",
|
||||
"\n",
|
||||
"# Create face model\n",
|
||||
"face_model = modules.FaceModel(\n",
|
||||
" encoder_config={\n",
|
||||
" 'hidden_size': 128,\n",
|
||||
" 'fc_size': 512, \n",
|
||||
" 'num_layers': 3,\n",
|
||||
" 'dropout_rate': 0.\n",
|
||||
" },\n",
|
||||
" decoder_config={\n",
|
||||
" 'hidden_size': 128,\n",
|
||||
" 'fc_size': 512, \n",
|
||||
" 'num_layers': 3,\n",
|
||||
" 'dropout_rate': 0.\n",
|
||||
" },\n",
|
||||
" class_conditional=False,\n",
|
||||
" max_seq_length=500,\n",
|
||||
" quantization_bits=8,\n",
|
||||
" decoder_cross_attention=True,\n",
|
||||
" use_discrete_vertex_embeddings=True,\n",
|
||||
")\n",
|
||||
"face_model_pred_dist = face_model(face_model_batch)\n",
|
||||
"face_model_loss = -tf.reduce_sum(\n",
|
||||
" face_model_pred_dist.log_prob(face_model_batch['faces']) * \n",
|
||||
" face_model_batch['faces_mask'])\n",
|
||||
"face_samples = face_model.sample(\n",
|
||||
" context=vertex_samples, max_sample_length=500, top_p=0.95,\n",
|
||||
" only_return_complete=False)\n",
|
||||
"print(face_model_batch)\n",
|
||||
"print(face_model_pred_dist)\n",
|
||||
"print(face_samples)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "hL7yloXB1pUb"
|
||||
},
|
||||
"source": [
|
||||
"## Train on the synthetic data\n",
|
||||
"\n",
|
||||
"Now that we've created vertex and face models and their respective data loaders, we can train them and look at some outputs. While we train the models together here, they can be trained separately and recombined later if required. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "hjrbofa8zqQt"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Optimization settings\n",
|
||||
"learning_rate = 5e-4\n",
|
||||
"training_steps = 500\n",
|
||||
"check_step = 5\n",
|
||||
"\n",
|
||||
"# Create an optimizer and minimize the summed negative log probability of the mesh \n",
|
||||
"# sequences\n",
|
||||
"optimizer = tf.train.AdamOptimizer(learning_rate)\n",
|
||||
"vertex_model_optim_op = optimizer.minimize(vertex_model_loss)\n",
|
||||
"face_model_optim_op = optimizer.minimize(face_model_loss)\n",
|
||||
"\n",
|
||||
"# Training loop\n",
|
||||
"with tf.Session() as sess:\n",
|
||||
" sess.run(tf.global_variables_initializer())\n",
|
||||
" for n in range(training_steps):\n",
|
||||
" if n % check_step == 0:\n",
|
||||
" v_loss, f_loss = sess.run((vertex_model_loss, face_model_loss))\n",
|
||||
" print('Step {}'.format(n))\n",
|
||||
" print('Loss (vertices) {}'.format(v_loss))\n",
|
||||
" print('Loss (faces) {}'.format(f_loss))\n",
|
||||
" v_samples_np, f_samples_np, b_np = sess.run(\n",
|
||||
" (vertex_samples, face_samples, vertex_model_batch))\n",
|
||||
" mesh_list = []\n",
|
||||
" for n in range(4):\n",
|
||||
" mesh_list.append(\n",
|
||||
" {\n",
|
||||
" 'vertices': v_samples_np['vertices'][n][:v_samples_np['num_vertices'][n]],\n",
|
||||
" 'faces': data_utils.unflatten_faces(\n",
|
||||
" f_samples_np['faces'][n][:f_samples_np['num_face_indices'][n]])\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
" data_utils.plot_meshes(mesh_list, ax_lims=0.5)\n",
|
||||
" sess.run((vertex_model_optim_op, face_model_optim_op))"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"name": "training.ipynb",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1QL8ib2FKPGUWFQbuX8AttUk-H34Al8Ue",
|
||||
"timestamp": 1591364245034
|
||||
},
|
||||
{
|
||||
"file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
|
||||
"timestamp": 1591355096822
|
||||
}
|
||||
],
|
||||
"toc_visible": true
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
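The training losses in the notebook above sum the log-probabilities of the target tokens over unpadded positions and negate the result. A minimal NumPy sketch of that computation follows, assuming a categorical distribution parameterized by `logits`. `masked_nll` is a hypothetical helper written for illustration and is not part of `modules.py` or `data_utils.py`.

```python
import numpy as np


def masked_nll(logits, targets, mask):
  """Summed negative log-likelihood over unpadded sequence positions.

  Args:
    logits: float array of shape [batch, length, num_tokens].
    targets: int array of shape [batch, length].
    mask: float array of shape [batch, length], 1 for real tokens, 0 for padding.
  """
  logits = np.asarray(logits, dtype=np.float64)
  # Log-softmax, subtracting the max for numerical stability.
  shifted = logits - logits.max(axis=-1, keepdims=True)
  log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
  # Select the log-probability of the target token at each position.
  target_log_probs = np.take_along_axis(
      log_probs, np.asarray(targets)[..., None], axis=-1)[..., 0]
  # Padded positions are multiplied by zero and do not contribute.
  return -(target_log_probs * np.asarray(mask)).sum()


# Uniform logits over 4 tokens; the last position is padding.
example_loss = masked_nll(
    np.zeros((1, 3, 4)), np.array([[0, 1, 2]]), np.array([[1., 1., 0.]]))
print(example_loss)  # Equals 2 * log(4), since two positions are unmasked
```

Summing rather than averaging means longer sequences contribute more to the loss, which matches the `tf.reduce_sum` used in the notebook.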