Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+1 -1
View File
@@ -91,7 +91,7 @@ def make_face_model_dataset(
# Vertices are quantized. So convert to floats for input to face model
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
example['vertices_mask'] = tf.ones_like(
example['vertices'][Ellipsis, 0], dtype=tf.float32)
example['vertices'][..., 0], dtype=tf.float32)
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
return example
return ds.map(_face_model_map_fn)
+8 -8
View File
@@ -799,7 +799,7 @@ class VertexModel(snt.AbstractModule):
# Continuous vertex value embeddings
else:
vert_embeddings = tf.layers.dense(
dequantize_verts(vertices[Ellipsis, None], self.quantization_bits),
dequantize_verts(vertices[..., None], self.quantization_bits),
self.embedding_dim,
use_bias=True,
name='value_embeddings')
@@ -984,7 +984,7 @@ class VertexModel(snt.AbstractModule):
verts_dequantized = dequantize_verts(v, self.quantization_bits)
vertices = tf.reshape(verts_dequantized, [num_samples, -1, 3])
vertices = tf.stack(
[vertices[Ellipsis, 2], vertices[Ellipsis, 1], vertices[Ellipsis, 0]], axis=-1)
[vertices[..., 2], vertices[..., 1], vertices[..., 0]], axis=-1)
# Pad samples to max sample length. This is required in order to concatenate
# Samples across different replicator instances. Pad with stopping tokens
@@ -998,14 +998,14 @@ class VertexModel(snt.AbstractModule):
if recenter_verts:
vert_max = tf.reduce_max(
vertices - 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
vertices - 1e10 * (1. - vertices_mask)[..., None], axis=1,
keepdims=True)
vert_min = tf.reduce_min(
vertices + 1e10 * (1. - vertices_mask)[Ellipsis, None], axis=1,
vertices + 1e10 * (1. - vertices_mask)[..., None], axis=1,
keepdims=True)
vert_centers = 0.5 * (vert_max + vert_min)
vertices -= vert_centers
vertices *= vertices_mask[Ellipsis, None]
vertices *= vertices_mask[..., None]
if only_return_complete:
vertices = tf.boolean_mask(vertices, completed)
@@ -1247,7 +1247,7 @@ class FaceModel(snt.AbstractModule):
sequential_context_embeddings = (
vertex_embeddings *
tf.pad(context['vertices_mask'], [[0, 0], [2, 0]],
constant_values=1)[Ellipsis, None])
constant_values=1)[..., None])
else:
sequential_context_embeddings = None
return (vertex_embeddings, global_context_embedding,
@@ -1266,11 +1266,11 @@ class FaceModel(snt.AbstractModule):
embed_dim=self.embedding_dim,
initializers={'embeddings': tf.glorot_uniform_initializer},
densify_gradients=True,
name='coord_{}'.format(c))(verts_quantized[Ellipsis, c])
name='coord_{}'.format(c))(verts_quantized[..., c])
else:
vertex_embeddings = tf.layers.dense(
vertices, self.embedding_dim, use_bias=True, name='vertex_embeddings')
vertex_embeddings *= vertices_mask[Ellipsis, None]
vertex_embeddings *= vertices_mask[..., None]
# Pad vertex embeddings with learned embeddings for stopping and new face
# tokens