deepmind-research/polygen/sample-pretrained.ipynb
Charlie Nash 22c3daff19 Adds PolyGen to public Deepmind Research Github repository
PiperOrigin-RevId: 327802622
2020-08-21 15:42:20 +01:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "lUh_eWpealmh"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
" https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software\n",
"distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"See the License for the specific language governing permissions and\n",
"limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kYd9gIfGJYZ8"
},
"source": [
"## Clone repo and import dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "Ux33ZDQ_tqUV"
},
"outputs": [],
"source": [
"!pip install tensorflow==1.15\n",
"!pip install dm-sonnet==1.36\n",
"!pip install tensor2tensor==1.14\n",
"\n",
"import time\n",
"import numpy as np\n",
"import tensorflow.compat.v1 as tf\n",
"tf.logging.set_verbosity(tf.logging.ERROR) # Hide TF deprecation messages\n",
"import matplotlib.pyplot as plt\n",
"\n",
"!git clone https://github.com/deepmind/deepmind-research.git deepmind_research\n",
"%cd deepmind_research/polygen\n",
"import modules\n",
"import data_utils"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "YUZNqHTVJbm3"
},
"source": [
"## Download pre-trained model weights from Google Cloud Storage"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "LpZjBUmq10gX"
},
"outputs": [],
"source": [
"!mkdir /tmp/vertex_model\n",
"!mkdir /tmp/face_model\n",
"!gsutil cp gs://deepmind-research-polygen/vertex_model.tar.gz /tmp/vertex_model/\n",
"!gsutil cp gs://deepmind-research-polygen/face_model.tar.gz /tmp/face_model/\n",
"!tar xvfz /tmp/vertex_model/vertex_model.tar.gz -C /tmp/vertex_model/\n",
"!tar xvfz /tmp/face_model/face_model.tar.gz -C /tmp/face_model/"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "vXhMOoxaJ3Xb"
},
"source": [
"## Pre-trained model config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "7VGSSS9vJSn-"
},
"outputs": [],
"source": [
"vertex_module_config=dict(\n",
" decoder_config=dict(\n",
" hidden_size=512,\n",
" fc_size=2048,\n",
" num_heads=8,\n",
" layer_norm=True,\n",
" num_layers=24,\n",
" dropout_rate=0.4,\n",
" re_zero=True,\n",
" memory_efficient=True\n",
" ),\n",
" quantization_bits=8,\n",
" class_conditional=True,\n",
" max_num_input_verts=5000,\n",
" use_discrete_embeddings=True,\n",
" )\n",
"\n",
"face_module_config=dict(\n",
" encoder_config=dict(\n",
" hidden_size=512,\n",
" fc_size=2048,\n",
" num_heads=8,\n",
" layer_norm=True,\n",
" num_layers=10,\n",
" dropout_rate=0.2,\n",
" re_zero=True,\n",
" memory_efficient=True,\n",
" ),\n",
" decoder_config=dict(\n",
" hidden_size=512,\n",
" fc_size=2048,\n",
" num_heads=8,\n",
" layer_norm=True,\n",
" num_layers=14,\n",
" dropout_rate=0.2,\n",
" re_zero=True,\n",
" memory_efficient=True,\n",
" ),\n",
" class_conditional=False,\n",
" decoder_cross_attention=True,\n",
" use_discrete_vertex_embeddings=True,\n",
" max_seq_length=8000,\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WNXf_XbuKW4S"
},
"source": [
"## Generate class-conditional samples\n",
"\n",
"Try varying the `class_id` parameter to generate meshes from different object categories. Good classes to try are tables (49), lamps (30), and cabinets (32). \n",
"\n",
"We can also specify the maximum number of vertices / face indices we want to see in the generated meshes using `max_num_vertices` and `max_num_face_indices`. The code will keep generating batches of samples until there are at least `num_samples_min` complete samples with the required number of vertices / faces.\n",
"\n",
"`top_p_vertex_model` and `top_p_face_model` control how varied the outputs are, with `1.` being the most varied, and `0.` the least varied. `0.9` is a good value for both the vertex and face models.\n",
"\n",
"Sampling should take around 2-5 minutes with a colab GPU using the default settings depending on the object class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "kqKMbPJJu3lk"
},
"outputs": [],
"source": [
"class_id = '49) table' #@param ['0) airplane,aeroplane,plane','1) ashcan,trash can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin','2) bag,traveling bag,travelling bag,grip,suitcase','3) basket,handbasket','4) bathtub,bathing tub,bath,tub','5) bed','6) bench','7) birdhouse','8) bookshelf','9) bottle','10) bowl','11) bus,autobus,coach,charabanc,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehi','12) cabinet','13) camera,photographic camera','14) can,tin,tin can','15) cap','16) car,auto,automobile,machine,motorcar','17) cellular telephone,cellular phone,cellphone,cell,mobile phone','18) chair','19) clock','20) computer keyboard,keypad','21) dishwasher,dish washer,dishwashing machine','22) display,video display','23) earphone,earpiece,headphone,phone','24) faucet,spigot','25) file,file cabinet,filing cabinet','26) guitar','27) helmet','28) jar','29) knife','30) lamp','31) laptop,laptop computer','32) loudspeaker,speaker,speaker unit,loudspeaker system,speaker system','33) mailbox,letter box','34) microphone,mike','35) microwave,microwave oven','36) motorcycle,bike','37) mug','38) piano,pianoforte,forte-piano','39) pillow','40) pistol,handgun,side arm,shooting iron','41) pot,flowerpot','42) printer,printing machine','43) remote control,remote','44) rifle','45) rocket,projectile','46) skateboard','47) sofa,couch,lounge','48) stove','49) table','50) telephone,phone,telephone set','51) tower','52) train,railroad train','53) vessel,watercraft','54) washer,automatic washer,washing machine']\n",
"num_samples_min = 1 #@param\n",
"num_samples_batch = 8 #@param\n",
"max_num_vertices = 400 #@param\n",
"max_num_face_indices = 2000 #@param\n",
"top_p_vertex_model = 0.9 #@param\n",
"top_p_face_model = 0.9 #@param\n",
"\n",
"tf.reset_default_graph()\n",
"\n",
"# Build models\n",
"vertex_model = modules.VertexModel(**vertex_module_config)\n",
"face_model = modules.FaceModel(**face_module_config)\n",
"\n",
"# Tile out class label to every element in batch\n",
"class_id = int(class_id.split(')')[0])\n",
"vertex_model_context = {'class_label': tf.fill([num_samples_batch,], class_id)}\n",
"vertex_samples = vertex_model.sample(\n",
" num_samples_batch, context=vertex_model_context, \n",
" max_sample_length=max_num_vertices, top_p=top_p_vertex_model, \n",
" recenter_verts=True, only_return_complete=True)\n",
"vertex_model_saver = tf.train.Saver(var_list=vertex_model.variables)\n",
"\n",
"# The face model generates samples conditioned on a context, which here is\n",
"# the vertex model samples\n",
"face_samples = face_model.sample(\n",
" vertex_samples, max_sample_length=max_num_face_indices, \n",
" top_p=top_p_face_model, only_return_complete=True)\n",
"face_model_saver = tf.train.Saver(var_list=face_model.variables)\n",
"\n",
"# Start sampling\n",
"start = time.time()\n",
"print('Generating samples...')\n",
"with tf.Session() as sess:\n",
" vertex_model_saver.restore(sess, '/tmp/vertex_model/model')\n",
" face_model_saver.restore(sess, '/tmp/face_model/model')\n",
" mesh_list = []\n",
" num_samples_complete = 0\n",
" while num_samples_complete \u003c num_samples_min:\n",
" v_samples_np = sess.run(vertex_samples)\n",
" if v_samples_np['completed'].size == 0:\n",
" print('No vertex samples completed in this batch. Try increasing ' +\n",
" 'max_num_vertices.')\n",
" continue\n",
" f_samples_np = sess.run(\n",
" face_samples,\n",
" {vertex_samples[k]: v_samples_np[k] for k in vertex_samples.keys()})\n",
" v_samples_np = f_samples_np['context']\n",
" num_samples_complete_batch = f_samples_np['completed'].sum()\n",
" num_samples_complete += num_samples_complete_batch\n",
" print('Num. samples complete: {}'.format(num_samples_complete))\n",
" for k in range(num_samples_complete_batch):\n",
" verts = v_samples_np['vertices'][k][:v_samples_np['num_vertices'][k]]\n",
" faces = data_utils.unflatten_faces(\n",
" f_samples_np['faces'][k][:f_samples_np['num_face_indices'][k]])\n",
" mesh_list.append({'vertices': verts, 'faces': faces})\n",
"end = time.time()\n",
"print('sampling time: {}'.format(end - start))\n",
"\n",
"data_utils.plot_meshes(mesh_list, ax_lims=0.4) "
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "OOQV6pMvSymz"
},
"source": [
"## Export meshes as `.obj` files\n",
"Pick a `mesh_id` (starting at 0) corresponding to the samples generated above. Refresh the colab file browser to find an `.obj` file with the mesh data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "fO0Klbq2Sx0m"
},
"outputs": [],
"source": [
"mesh_id = 4 #@param\n",
"data_utils.write_obj(\n",
" mesh_list[mesh_id]['vertices'], mesh_list[mesh_id]['faces'], \n",
" 'mesh-{}.obj'.format(mesh_id))"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "sampling-pretrained.ipynb",
"provenance": [
{
"file_id": "1yj-oHYqCnwYVGSM22coa68NqnSoNdyVr",
"timestamp": 1591622616999
},
{
"file_id": "1v_7DtLnpXrEhVbwZhzDiVQW7ghroi11Y",
"timestamp": 1591264007511
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}