mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 02:35:47 +08:00
Enable loading and processing OBJ meshes from memory.
Also remove dependency on modules from data_utils to make it more hermetic. PiperOrigin-RevId: 406120785
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
5f6b11299e
commit
d25484af7f
+52
-34
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
"""Mesh data utilities."""
|
"""Mesh data utilities."""
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import modules
|
|
||||||
from mpl_toolkits import mplot3d # pylint: disable=unused-import
|
from mpl_toolkits import mplot3d # pylint: disable=unused-import
|
||||||
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
@@ -88,8 +87,16 @@ def make_face_model_dataset(
|
|||||||
example['faces'] = tf.cast(
|
example['faces'] = tf.cast(
|
||||||
tf.gather(face_permutation, example['faces']), tf.int64)
|
tf.gather(face_permutation, example['faces']), tf.int64)
|
||||||
|
|
||||||
|
def _dequantize_verts(verts, n_bits):
|
||||||
|
min_range = -0.5
|
||||||
|
max_range = 0.5
|
||||||
|
range_quantize = 2**n_bits - 1
|
||||||
|
verts = tf.cast(verts, tf.float32)
|
||||||
|
verts = verts * (max_range - min_range) / range_quantize + min_range
|
||||||
|
return verts
|
||||||
|
|
||||||
# Vertices are quantized. So convert to floats for input to face model
|
# Vertices are quantized. So convert to floats for input to face model
|
||||||
example['vertices'] = modules.dequantize_verts(vertices, quantization_bits)
|
example['vertices'] = _dequantize_verts(vertices, quantization_bits)
|
||||||
example['vertices_mask'] = tf.ones_like(
|
example['vertices_mask'] = tf.ones_like(
|
||||||
example['vertices'][..., 0], dtype=tf.float32)
|
example['vertices'][..., 0], dtype=tf.float32)
|
||||||
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
|
example['faces_mask'] = tf.ones_like(example['faces'], dtype=tf.float32)
|
||||||
@@ -97,44 +104,50 @@ def make_face_model_dataset(
|
|||||||
return ds.map(_face_model_map_fn)
|
return ds.map(_face_model_map_fn)
|
||||||
|
|
||||||
|
|
||||||
def read_obj(obj_path):
|
def read_obj_file(obj_file):
|
||||||
"""Read vertices and faces from .obj file."""
|
"""Read vertices and faces from already opened file."""
|
||||||
vertex_list = []
|
vertex_list = []
|
||||||
flat_vertices_list = []
|
flat_vertices_list = []
|
||||||
flat_vertices_indices = {}
|
flat_vertices_indices = {}
|
||||||
flat_triangles = []
|
flat_triangles = []
|
||||||
|
|
||||||
with open(obj_path) as obj_file:
|
for line in obj_file:
|
||||||
for line in obj_file:
|
tokens = line.split()
|
||||||
tokens = line.split()
|
if not tokens:
|
||||||
if not tokens:
|
continue
|
||||||
continue
|
line_type = tokens[0]
|
||||||
line_type = tokens[0]
|
# We skip lines not starting with v or f.
|
||||||
# We skip lines not starting with v or f.
|
if line_type == 'v':
|
||||||
if line_type == 'v':
|
vertex_list.append([float(x) for x in tokens[1:]])
|
||||||
vertex_list.append([float(x) for x in tokens[1:]])
|
elif line_type == 'f':
|
||||||
elif line_type == 'f':
|
triangle = []
|
||||||
triangle = []
|
for i in range(len(tokens) - 1):
|
||||||
for i in range(len(tokens) - 1):
|
vertex_name = tokens[i + 1]
|
||||||
vertex_name = tokens[i + 1]
|
if vertex_name in flat_vertices_indices:
|
||||||
if vertex_name in flat_vertices_indices:
|
triangle.append(flat_vertices_indices[vertex_name])
|
||||||
triangle.append(flat_vertices_indices[vertex_name])
|
continue
|
||||||
|
flat_vertex = []
|
||||||
|
for index in six.ensure_str(vertex_name).split('/'):
|
||||||
|
if not index:
|
||||||
continue
|
continue
|
||||||
flat_vertex = []
|
# obj triangle indices are 1 indexed, so subtract 1 here.
|
||||||
for index in six.ensure_str(vertex_name).split('/'):
|
flat_vertex += vertex_list[int(index) - 1]
|
||||||
if not index:
|
flat_vertex_index = len(flat_vertices_list)
|
||||||
continue
|
flat_vertices_list.append(flat_vertex)
|
||||||
# obj triangle indices are 1 indexed, so subtract 1 here.
|
flat_vertices_indices[vertex_name] = flat_vertex_index
|
||||||
flat_vertex += vertex_list[int(index) - 1]
|
triangle.append(flat_vertex_index)
|
||||||
flat_vertex_index = len(flat_vertices_list)
|
flat_triangles.append(triangle)
|
||||||
flat_vertices_list.append(flat_vertex)
|
|
||||||
flat_vertices_indices[vertex_name] = flat_vertex_index
|
|
||||||
triangle.append(flat_vertex_index)
|
|
||||||
flat_triangles.append(triangle)
|
|
||||||
|
|
||||||
return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
|
return np.array(flat_vertices_list, dtype=np.float32), flat_triangles
|
||||||
|
|
||||||
|
|
||||||
|
def read_obj(obj_path):
|
||||||
|
"""Open .obj file from the path provided and read vertices and faces."""
|
||||||
|
|
||||||
|
with open(obj_path) as obj_file:
|
||||||
|
return read_obj_file(obj_file)
|
||||||
|
|
||||||
|
|
||||||
def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
|
def write_obj(vertices, faces, file_path, transpose=True, scale=1.):
|
||||||
"""Write vertices and faces to obj."""
|
"""Write vertices and faces to obj."""
|
||||||
if transpose:
|
if transpose:
|
||||||
@@ -286,10 +299,8 @@ def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8):
|
|||||||
return vertices, faces, tris
|
return vertices, faces, tris
|
||||||
|
|
||||||
|
|
||||||
def load_process_mesh(mesh_obj_path, quantization_bits=8):
|
def process_mesh(vertices, faces, quantization_bits=8):
|
||||||
"""Load obj file and process."""
|
"""Process mesh vertices and faces."""
|
||||||
# Load mesh
|
|
||||||
vertices, faces = read_obj(mesh_obj_path)
|
|
||||||
|
|
||||||
# Transpose so that z-axis is vertical.
|
# Transpose so that z-axis is vertical.
|
||||||
vertices = vertices[:, [2, 0, 1]]
|
vertices = vertices[:, [2, 0, 1]]
|
||||||
@@ -316,6 +327,13 @@ def load_process_mesh(mesh_obj_path, quantization_bits=8):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_process_mesh(mesh_obj_path, quantization_bits=8):
|
||||||
|
"""Load obj file and process."""
|
||||||
|
# Load mesh
|
||||||
|
vertices, faces = read_obj(mesh_obj_path)
|
||||||
|
return process_mesh(vertices, faces, quantization_bits)
|
||||||
|
|
||||||
|
|
||||||
def plot_meshes(mesh_list,
|
def plot_meshes(mesh_list,
|
||||||
ax_lims=0.3,
|
ax_lims=0.3,
|
||||||
fig_size=4,
|
fig_size=4,
|
||||||
|
|||||||
Reference in New Issue
Block a user