Add code for generating s3o4d data

PiperOrigin-RevId: 535584807
This commit is contained in:
David Pfau
2023-05-26 12:49:41 +01:00
committed by Saran Tunyasuvunakool
parent 9176a9f23c
commit f5de0ede84

172
geomancer/data_writer.py Normal file
View File

@@ -0,0 +1,172 @@
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data writer for Stanford Bunny experiments and other objects."""
# pylint: disable=unused-import
import copy
import io
import os
import time
from absl import app
from absl import flags
from absl import logging
from dm_control import mujoco
import numpy as np
# Command-line flags (absl). Their `.value` attributes are read in `main`.
# Index of the shard to generate; it is formatted into the output file name.
_SHARD = flags.DEFINE_integer('shard', 0, 'Shard index')
# Number of samples rendered and stored in one shard.
_SIZE = flags.DEFINE_integer('size', 1000,
'Number of images to save to a shard')
# Mesh to render; also used as the name of the output sub-folder under `path`.
_OBJECT = flags.DEFINE_string('object', 'dragon', 'Which object to render')
# Folder holding the .stl meshes; shards are written beneath this folder too.
_PATH = flags.DEFINE_string('path', '', 'Path to folder with .stl files')
# Resolution, in pixels, requested from the renderer in `render`.
render_height = 1024
render_width = 1024
# Resolution, in pixels, of the stored images. `render` averages blocks of
# rendered pixels to go from the render resolution down to this size.
height = 256
width = 256
def get_normal(x):
  """Return an orthonormal basis of the subspace orthogonal to a vector.

  Args:
    x: 1-D array of length n (it is indexed as `x[None, :]`).

  Returns:
    Array of shape (n, n - 1) whose columns are mutually orthonormal and
    orthogonal to `x`.
  """
  # For the 1 x n matrix x[None, :], the rows of `vh` are the right singular
  # vectors: row 0 is parallel to x and rows 1: span its orthogonal complement.
  _, _, vh = np.linalg.svd(x[None, :])
  # Select rows 1: and transpose so that the basis vectors are columns.
  # Slicing columns of `vh` instead is only equivalent when `vh` happens to be
  # symmetric, which the SVD does not guarantee.
  return vh[1:].T
def render(quat, light, mesh='bunny', meshdir='data'):
  """Render one image of a mesh for a given rotation and light direction.

  The scene contains a fixed camera, the mesh inside a body whose `quat`
  attribute is set from `quat`, and two directional lights placed at
  `5 * light` and `-5 * light`, each pointing back towards the origin.

  Args:
    quat: Length-4 array used (after normalization) as the body quaternion.
    light: Length-3 array used (after normalization) as the light direction.
    mesh: Name of the mesh, 'bunny' or 'dragon'. The file `<mesh>.stl` is
      loaded from `meshdir`.
    meshdir: Directory containing the .stl files.

  Returns:
    float32 array of shape (height, width, 3): the frame rendered at
    (render_height, render_width) and downsampled by block averaging.

  Raises:
    ValueError: If `mesh` is not one of the supported meshes.
  """
  # Per-mesh (uniform scale factor, z offset of the geom inside its body).
  mesh_settings = {
      'bunny': (0.03, -1.0),
      'dragon': (0.06, -0.3),
  }
  if mesh not in mesh_settings:
    raise ValueError('Unsupported mesh %r; expected one of %s' %
                     (mesh, sorted(mesh_settings)))
  scale, pos = mesh_settings[mesh]

  simple_world_mjcf_template = """
<mujoco>
<visual>
<headlight active="0"/>
<global offwidth="%s" offheight="%s"/>
</visual>
<compiler meshdir="%s"/>
<asset>
<mesh name="%s" file="%s.stl" scale="%g %g %g"/>
</asset>
<worldbody>
<camera name="main" pos="0 0 5" xyaxes="1 0 0 0 1 0"/>
<body name="obj" quat="{} {} {} {}">
<geom name="%s" type="mesh" mesh="%s" pos="0 0 %g"/>
</body>
<light pos="{} {} {}" directional="true" dir="{} {} {}"/>
<light pos="{} {} {}" directional="true" dir="{} {} {}"/>
</worldbody>
</mujoco>
"""

  # Normalize into new arrays so the caller's inputs are not modified.
  light = np.asarray(light) / np.linalg.norm(light)
  quat = np.asarray(quat) / np.linalg.norm(quat)

  # Values for the `{}` placeholders, in order: the quaternion, then
  # (pos, dir) for the first light and (pos, dir) for the second light.
  numbers = np.concatenate((quat,
                            5 * light, -5 * light,
                            -5 * light, 5 * light)).tolist()
  # Fill the numeric `{}` fields first and the `%` fields second, so that
  # braces occurring in `meshdir` or `mesh` are never parsed by `str.format`.
  simple_world_mjcf = simple_world_mjcf_template.format(*numbers) % (
      render_width, render_height, meshdir, mesh, mesh,
      scale, scale, scale, mesh, mesh, pos)

  physics = mujoco.Physics.from_xml_string(simple_world_mjcf)
  data = physics.render(camera_id='main',
                        height=render_height,
                        width=render_width).astype(np.float32)
  # Split rows and columns into (output index, position within block) so that
  # each output pixel is the mean of one block of rendered pixels. The row
  # axis is paired with the heights and the column axis with the widths;
  # this assumes the rendered frame is (render_height, render_width, 3).
  data = data.reshape((height, render_height // height,
                       width, render_width // width, 3))
  return np.mean(np.mean(data, axis=1), axis=2)
def get_tangent(quat, light, mesh='bunny', meshdir='data',
                eps=0.03, use_light=True, use_quat=True):
  """Render an image together with finite-difference tangent vectors.

  The image is rendered once at (quat, light) and once more for every tangent
  direction, with the corresponding latent moved by `eps` along that
  direction. Each image tangent is (perturbed - base) / eps.

  Args:
    quat: Length-4 array, the rotation latent passed to `render`.
    light: Length-3 array, the light direction latent passed to `render`.
    mesh: Mesh name, forwarded to `render`.
    meshdir: Directory with the .stl files, forwarded to `render`.
    eps: Step size of the finite differences.
    use_light: Whether to include the 2 tangent directions of `light`.
    use_quat: Whether to include the 3 tangent directions of `quat`.

  Returns:
    On success a tuple `(image, image_tangent, latent_tangent)` where, with
    n = 3 * use_quat + 2 * use_light:
      image: (height, width) array, the render averaged over channels.
      image_tangent: (n, height, width) array, rotation directions first.
      latent_tangent: (7, n) array. Rows 0-3 are the quaternion components
        and rows 4-6 the light components; each column is the latent-space
        direction of the matching entry of `image_tangent`.
    `None` if rendering raised an exception (the failure is logged).
  """
  assert use_light or use_quat, 'At least one of use_light/use_quat is needed'

  quat_tangent = get_normal(quat) if use_quat else None
  light_tangent = get_normal(light) if use_light else None
  n = (3 if use_quat else 0) + (2 if use_light else 0)

  try:
    data = render(quat, light, mesh=mesh, meshdir=meshdir)
    image_tangent = np.zeros((n, height, width, 3), dtype=np.float32)
    row = 0
    if use_quat:
      for i in range(3):
        perturbed = render(quat + eps * quat_tangent[:, i],
                           light, mesh=mesh, meshdir=meshdir)
        image_tangent[row] = (perturbed - data) / eps
        row += 1
    if use_light:
      for i in range(2):
        perturbed = render(quat, light + eps * light_tangent[:, i],
                           mesh=mesh, meshdir=meshdir)
        image_tangent[row] = (perturbed - data) / eps
        row += 1
  except Exception:  # pylint: disable=broad-except
    # Best effort: record the failing latents with the traceback and let the
    # caller decide what to do. KeyboardInterrupt/SystemExit still propagate.
    logging.exception('Failed with latents (quat: %s, light: %s)', quat, light)
    return None

  # Subtract, per pixel, the mean over the tangent directions.
  image_tangent -= np.mean(image_tangent, axis=0)[None, ...]

  # Embed each factor's tangent basis in the 7-dimensional (quat, light)
  # latent space, padding the other factor's rows with zeros. With both
  # factors enabled this is the block-diagonal matrix
  # [[quat_tangent, 0], [0, light_tangent]].
  latent_columns = []
  if use_quat:
    latent_columns.append(
        np.vstack([quat_tangent, np.zeros((3, 3), dtype=np.float32)]))
  if use_light:
    latent_columns.append(
        np.vstack([np.zeros((4, 2), dtype=np.float32), light_tangent]))
  latent_tangent = np.hstack(latent_columns)

  return (np.mean(data, axis=-1),
          np.mean(image_tangent, axis=-1),
          latent_tangent)
def main(_):
  """Render one shard of images with tangents and save it as an .npz file.

  For every sample a random unit light direction and a random unit quaternion
  are drawn and passed to `get_tangent`. The shard is written to
  `<path>/<object>/shard_<NNN>.npz` with four positional arrays, which
  `np.savez` stores under the keys `arr_0` .. `arr_3`:
    arr_0: images, (size, height, width)
    arr_1: latents, (size, 7), stored as (light, quat)
    arr_2: image tangents, (size, 5, height, width)
    arr_3: latent tangents, (size, 7, 5)

  Raises:
    RuntimeError: If a sample cannot be rendered after several attempts.
  """
  num_images = _SIZE.value
  images = np.zeros((num_images, height, width), dtype=np.float32)
  latents = np.zeros((num_images, 7), dtype=np.float32)
  image_tangents = np.zeros((num_images, 5, height, width), dtype=np.float32)
  latent_tangents = np.zeros((num_images, 7, 5), dtype=np.float32)

  # `get_tangent` returns None when rendering fails. Draw fresh latents a
  # bounded number of times instead of failing on tuple unpacking.
  max_attempts = 10
  for i in range(num_images):
    result = None
    latent = None
    for _ in range(max_attempts):
      light = np.random.randn(3)
      light /= np.linalg.norm(light)
      quat = np.random.randn(4)  # rotation represented as quaternion
      quat /= np.linalg.norm(quat)
      # NOTE(review): latents are stored as (light, quat), whereas the rows of
      # the latent tangents returned by `get_tangent` are ordered
      # (quat, light). Kept as is to preserve the output format.
      latent = np.concatenate((light, quat))
      result = get_tangent(
          quat, light, mesh=_OBJECT.value, meshdir=_PATH.value)
      if result is not None:
        break
    if result is None:
      raise RuntimeError(
          'Rendering image %d failed %d times in a row' % (i, max_attempts))
    latents[i] = latent
    images[i], image_tangents[i], latent_tangents[i] = result
    logging.info('Rendered image %d of %d', i + 1, num_images)

  out_dir = os.path.join(_PATH.value, _OBJECT.value)
  os.makedirs(out_dir, exist_ok=True)
  shard_path = os.path.join(out_dir, 'shard_%03d.npz' % _SHARD.value)
  with open(shard_path, 'wb') as f:
    # Write straight to the file: going through an in-memory buffer would
    # hold extra copies of the whole shard in memory.
    np.savez(f, images, latents, image_tangents, latent_tangents)


if __name__ == '__main__':
  app.run(main)