deepmind-research/polygen/training.ipynb
Tom Ward 3d9f601c69 Update PolyGen installation to do single pip install & clone into /tmp.
This fixes issues with dep versioning resolution between each package.

PiperOrigin-RevId: 391343870
2021-08-27 15:02:27 +01:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "szgmaK1HajOc"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_dv0afOrKheU"
},
"source": [
"## Clone repo and import dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ux33ZDQ_tqUV"
},
"outputs": [],
"source": [
"!pip install tensorflow==1.15 dm-sonnet==1.36 tensor2tensor==1.14\n",
"\n",
"import os\n",
"import numpy as np\n",
"import tensorflow.compat.v1 as tf\n",
"tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%cd /tmp\n",
"%rm -rf /tmp/deepmind_research\n",
"!git clone https://github.com/deepmind/deepmind-research.git \\\n",
" /tmp/deepmind_research\n",
"%cd /tmp/deepmind_research/polygen\n",
"import modules\n",
"import data_utils"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U3GDZhJ5wGOf"
},
"source": [
"## Prepare a synthetic dataset\n",
"We prepare a dataset of meshes using four simple geometric primitives.\n",
"\n",
"The important function here is `data_utils.load_process_mesh`, which loads the raw `.obj` file, normalizes and centers the meshes, and applies quantization to the vertex positions. The mesh faces are flattened and treated as a long sequence, with a new-face token (`=1`) separating the faces. For each of the four synthetic meshes, we associate a unique class label, so we can train class-conditional models.\n",
"\n",
"After processing the raw mesh data into numpy arrays, we create a `tf.data.Dataset` that we can use to feed data to our models. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "3QAqwyZjtOdC"
},
"outputs": [],
"source": [
"# Prepare synthetic dataset\n",
"ex_list = []\n",
"for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):\n",
"  mesh_dict = data_utils.load_process_mesh(\n",
"      os.path.join('meshes', '{}.obj'.format(mesh)))\n",
"  mesh_dict['class_label'] = k\n",
"  ex_list.append(mesh_dict)\n",
"synthetic_dataset = tf.data.Dataset.from_generator(\n",
"    lambda: ex_list,\n",
"    output_types={\n",
"        'vertices': tf.int32, 'faces': tf.int32, 'class_label': tf.int32},\n",
"    output_shapes={\n",
"        'vertices': tf.TensorShape([None, 3]), 'faces': tf.TensorShape([None]),\n",
"        'class_label': tf.TensorShape(())}\n",
"    )\n",
"ex = synthetic_dataset.make_one_shot_iterator().get_next()\n",
"\n",
"# Inspect the first mesh\n",
"with tf.Session() as sess:\n",
"  ex_np = sess.run(ex)\n",
"print(ex_np)\n",
"\n",
"# Plot the meshes\n",
"mesh_list = []\n",
"with tf.Session() as sess:\n",
"  for i in range(4):\n",
"    ex_np = sess.run(ex)\n",
"    mesh_list.append(\n",
"        {'vertices': data_utils.dequantize_verts(ex_np['vertices']),\n",
"         'faces': data_utils.unflatten_faces(ex_np['faces'])})\n",
"data_utils.plot_meshes(mesh_list, ax_lims=0.4)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9G2FCQQyyTXw"
},
"source": [
"## Vertex model\n",
"\n",
"#### Prepare the dataset for vertex model training\n",
"We need to perform some additional processing to make the dataset ready for vertex model training. In particular, `data_utils.make_vertex_model_dataset` flattens the `[V, 3]` vertex arrays, ordering by `Z-\u003eY-\u003eX` coordinates. It also creates masks, which are used to mask padded elements in data batches. Random shifts can also be applied to make the modelling task more challenging; they are disabled here (`apply_random_shift=False`).\n",
"\n",
"#### Create a vertex model\n",
"`modules.VertexModel` is a Sonnet module. Calling it on a batch of data produces a predictive distribution over the vertex coordinates, with one sequential prediction per coordinate. The basis of the vertex model is a Transformer decoder, and we specify its parameters in `decoder_config`.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "o2KCoDeeFP8C"
},
"outputs": [],
"source": [
"# Prepare the dataset for vertex model training\n",
"vertex_model_dataset = data_utils.make_vertex_model_dataset(\n",
"    synthetic_dataset, apply_random_shift=False)\n",
"vertex_model_dataset = vertex_model_dataset.repeat()\n",
"vertex_model_dataset = vertex_model_dataset.padded_batch(\n",
"    4, padded_shapes=vertex_model_dataset.output_shapes)\n",
"vertex_model_dataset = vertex_model_dataset.prefetch(1)\n",
"vertex_model_batch = vertex_model_dataset.make_one_shot_iterator().get_next()\n",
"\n",
"# Create vertex model\n",
"vertex_model = modules.VertexModel(\n",
"    decoder_config={\n",
"        'hidden_size': 128,\n",
"        'fc_size': 512,\n",
"        'num_layers': 3,\n",
"        'dropout_rate': 0.\n",
"    },\n",
"    class_conditional=True,\n",
"    num_classes=4,\n",
"    max_num_input_verts=250,\n",
"    quantization_bits=8,\n",
")\n",
"vertex_model_pred_dist = vertex_model(vertex_model_batch)\n",
"vertex_model_loss = -tf.reduce_sum(\n",
"    vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat']) *\n",
"    vertex_model_batch['vertices_flat_mask'])\n",
"vertex_samples = vertex_model.sample(\n",
"    4, context=vertex_model_batch, max_sample_length=200, top_p=0.95,\n",
"    recenter_verts=False, only_return_complete=False)\n",
"\n",
"print(vertex_model_batch)\n",
"print(vertex_model_pred_dist)\n",
"print(vertex_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "-9RNYr5x1jov"
},
"source": [
"## Face model\n",
"\n",
"#### Prepare the dataset for face model training\n",
"We need to perform some additional processing to make the dataset ready for face model training. In particular, `data_utils.make_face_model_dataset` prepares the vertices and flattened faces as inputs for the face model. It also creates masks, which are used to mask padded elements in data batches. Random shifts can also be applied to make the modelling task more challenging; they are disabled here (`apply_random_shift=False`).\n",
"\n",
"#### Create a face model\n",
"`modules.FaceModel` is a Sonnet module. Calling it on a batch of data produces a predictive distribution over the flattened face sequence, with one sequential prediction per face index. The face model is built from a Transformer encoder and a Transformer decoder, and we specify their parameters in `encoder_config` and `decoder_config`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "a2yO6dOGzn8c"
},
"outputs": [],
"source": [
"# Prepare the dataset for face model training\n",
"face_model_dataset = data_utils.make_face_model_dataset(\n",
"    synthetic_dataset, apply_random_shift=False)\n",
"face_model_dataset = face_model_dataset.repeat()\n",
"face_model_dataset = face_model_dataset.padded_batch(\n",
"    4, padded_shapes=face_model_dataset.output_shapes)\n",
"face_model_dataset = face_model_dataset.prefetch(1)\n",
"face_model_batch = face_model_dataset.make_one_shot_iterator().get_next()\n",
"\n",
"# Create face model\n",
"face_model = modules.FaceModel(\n",
"    encoder_config={\n",
"        'hidden_size': 128,\n",
"        'fc_size': 512,\n",
"        'num_layers': 3,\n",
"        'dropout_rate': 0.\n",
"    },\n",
"    decoder_config={\n",
"        'hidden_size': 128,\n",
"        'fc_size': 512,\n",
"        'num_layers': 3,\n",
"        'dropout_rate': 0.\n",
"    },\n",
"    class_conditional=False,\n",
"    max_seq_length=500,\n",
"    quantization_bits=8,\n",
"    decoder_cross_attention=True,\n",
"    use_discrete_vertex_embeddings=True,\n",
")\n",
"face_model_pred_dist = face_model(face_model_batch)\n",
"face_model_loss = -tf.reduce_sum(\n",
"    face_model_pred_dist.log_prob(face_model_batch['faces']) *\n",
"    face_model_batch['faces_mask'])\n",
"face_samples = face_model.sample(\n",
"    context=vertex_samples, max_sample_length=500, top_p=0.95,\n",
"    only_return_complete=False)\n",
"print(face_model_batch)\n",
"print(face_model_pred_dist)\n",
"print(face_samples)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "hL7yloXB1pUb"
},
"source": [
"## Train on the synthetic data\n",
"\n",
"Now that we've created vertex and face models and their respective data loaders, we can train them and look at some outputs. While we train the models together here, they can be trained separately and recombined later if required."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "hjrbofa8zqQt"
},
"outputs": [],
"source": [
"# Optimization settings\n",
"learning_rate = 5e-4\n",
"training_steps = 500\n",
"check_step = 5\n",
"\n",
"# Create an optimizer and minimize the summed negative log probability of the\n",
"# mesh sequences\n",
"optimizer = tf.train.AdamOptimizer(learning_rate)\n",
"vertex_model_optim_op = optimizer.minimize(vertex_model_loss)\n",
"face_model_optim_op = optimizer.minimize(face_model_loss)\n",
"\n",
"# Training loop\n",
"with tf.Session() as sess:\n",
"  sess.run(tf.global_variables_initializer())\n",
"  for n in range(training_steps):\n",
"    if n % check_step == 0:\n",
"      v_loss, f_loss = sess.run((vertex_model_loss, face_model_loss))\n",
"      print('Step {}'.format(n))\n",
"      print('Loss (vertices) {}'.format(v_loss))\n",
"      print('Loss (faces) {}'.format(f_loss))\n",
"      v_samples_np, f_samples_np, b_np = sess.run(\n",
"          (vertex_samples, face_samples, vertex_model_batch))\n",
"      mesh_list = []\n",
"      # Use a separate index so the training step counter n is not shadowed\n",
"      for i in range(4):\n",
"        mesh_list.append(\n",
"            {\n",
"                'vertices': v_samples_np['vertices'][i][:v_samples_np['num_vertices'][i]],\n",
"                'faces': data_utils.unflatten_faces(\n",
"                    f_samples_np['faces'][i][:f_samples_np['num_face_indices'][i]])\n",
"            }\n",
"        )\n",
"      data_utils.plot_meshes(mesh_list, ax_lims=0.5)\n",
"    sess.run((vertex_model_optim_op, face_model_optim_op))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "training.ipynb",
"provenance": [
{
"file_id": "1QL8ib2FKPGUWFQbuX8AttUk-H34Al8Ue",
"timestamp": 1591364245034
},
{
"file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
"timestamp": 1591355096822
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}