Enable loading and processing OBJ meshes from memory.

Also remove dependency on modules from data_utils to make it more hermetic.

PiperOrigin-RevId: 406120785
This commit is contained in:
Tom Ward
2021-10-28 14:04:54 +01:00
committed by Saran Tunyasuvunakool
parent 5f6b11299e
commit d25484af7f
+52 -34
View File
@@ -14,7 +14,6 @@
"""Mesh data utilities.""" """Mesh data utilities."""
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import modules
from mpl_toolkits import mplot3d # pylint: disable=unused-import from mpl_toolkits import mplot3d # pylint: disable=unused-import
from mpl_toolkits.mplot3d.art3d import Poly3DCollection from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import networkx as nx import networkx as nx
@@ -88,8 +87,16 @@ def make_face_model_dataset(
example['faces'] = tf.cast( example['faces'] = tf.cast(
tf.gather(face_permutation, example['faces']), tf.int64) tf.gather(face_permutation, example['faces']), tf.int64)
def _dequantize_verts(verts, n_bits):
min_range = -0.5
max_range = 0.5
range_quantize = 2**n_bits - 1
verts = tf.cast(verts, tf.float32)
verts = verts * (max_range - min_range) / range_quantize + min_range
return verts
# Vertices are quantized. So convert to floats for input to face model # Vertices are quantized. So convert to floats for input to face model
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits) example['vertices'] = _dequantize_verts(vertices, quantization_bits)
example['vertices_mask'] = tf.ones_like( example['vertices_mask'] = tf.ones_like(
example['vertices'][..., 0], dtype=tf.float32) example['vertices'][..., 0], dtype=tf.float32)
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32) example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
@@ -97,44 +104,50 @@ def make_face_model_dataset(
return ds.map(_face_model_map_fn) return ds.map(_face_model_map_fn)
def read_obj(obj_path): def read_obj_file(obj_file):
"""Read vertices and faces from .obj file.""" """Read vertices and faces from already opened file."""
vertex_list = [] vertex_list = []
flat_vertices_list = [] flat_vertices_list = []
flat_vertices_indices = {} flat_vertices_indices = {}
flat_triangles = [] flat_triangles = []
with open(obj_path) as obj_file: for line in obj_file:
for line in obj_file: tokens = line.split()
tokens = line.split() if not tokens:
if not tokens: continue
continue line_type = tokens[0]
line_type = tokens[0] # We skip lines not starting with v or f.
# We skip lines not starting with v or f. if line_type == 'v':
if line_type == 'v': vertex_list.append([float(x) for x in tokens[1:]])
vertex_list.append([float(x) for x in tokens[1:]]) elif line_type == 'f':
elif line_type == 'f': triangle = []
triangle = [] for i in range(len(tokens) - 1):
for i in range(len(tokens) - 1): vertex_name = tokens[i + 1]
vertex_name = tokens[i + 1] if vertex_name in flat_vertices_indices:
if vertex_name in flat_vertices_indices: triangle.append(flat_vertices_indices[vertex_name])
triangle.append(flat_vertices_indices[vertex_name]) continue
flat_vertex = []
for index in six.ensure_str(vertex_name).split('/'):
if not index:
continue continue
flat_vertex = [] # obj triangle indices are 1 indexed, so subtract 1 here.
for index in six.ensure_str(vertex_name).split('/'): flat_vertex += vertex_list[int(index) - 1]
if not index: flat_vertex_index = len(flat_vertices_list)
continue flat_vertices_list.append(flat_vertex)
# obj triangle indices are 1 indexed, so subtract 1 here. flat_vertices_indices[vertex_name] = flat_vertex_index
flat_vertex += vertex_list[int(index) - 1] triangle.append(flat_vertex_index)
flat_vertex_index = len(flat_vertices_list) flat_triangles.append(triangle)
flat_vertices_list.append(flat_vertex)
flat_vertices_indices[vertex_name] = flat_vertex_index
triangle.append(flat_vertex_index)
flat_triangles.append(triangle)
return np.array(flat_vertices_list, dtype=np.float32), flat_triangles return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
def read_obj(obj_path):
"""Open .obj file from the path provided and read vertices and faces."""
with open(obj_path) as obj_file:
return read_obj_file(obj_file)
def write_obj(vertices, faces, file_path, transpose=True, scale=1.): def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
"""Write vertices and faces to obj.""" """Write vertices and faces to obj."""
if transpose: if transpose:
@@ -286,10 +299,8 @@ def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8):
return vertices, faces, tris return vertices, faces, tris
def load_process_mesh(mesh_obj_path, quantization_bits=8): def process_mesh(vertices, faces, quantization_bits=8):
"""Load obj file and process.""" """Process mesh vertices and faces."""
# Load mesh
vertices, faces = read_obj(mesh_obj_path)
# Transpose so that z-axis is vertical. # Transpose so that z-axis is vertical.
vertices = vertices[:, [2, 0, 1]] vertices = vertices[:, [2, 0, 1]]
@@ -316,6 +327,13 @@ def load_process_mesh(mesh_obj_path, quantization_bits=8):
} }
def load_process_mesh(mesh_obj_path, quantization_bits=8):
"""Load obj file and process."""
# Load mesh
vertices, faces = read_obj(mesh_obj_path)
return process_mesh(vertices, faces, quantization_bits)
def plot_meshes(mesh_list, def plot_meshes(mesh_list,
ax_lims=0.3, ax_lims=0.3,
fig_size=4, fig_size=4,