Files
deepmind-research/catch_carry/mocap_data.py

188 lines
5.1 KiB
Python

# Copyright 2020 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Metadata for mocap clips that correspond to a walker carrying a prop."""
import collections
import enum
import os
from dm_control.locomotion.mocap import loader as mocap_loader
from catch_carry import trajectories
H5_DIR = os.path.dirname(__file__)
H5_PATH = os.path.join(H5_DIR, 'mocap_data.h5')
IDENTIFIER_PREFIX = 'DeepMindCatchCarry'
IDENTIFIER_TEMPLATE = IDENTIFIER_PREFIX + '-{:03d}'
ClipInfo = collections.namedtuple(
'ClipInfo', ('clip_identifier', 'num_steps', 'dt', 'flags'))
class Flag(enum.IntEnum):
BOX = 1 << 0
BALL = 1 << 1
LIGHT_PROP = 1 << 2
HEAVY_PROP = 1 << 3
SMALL_PROP = 1 << 4
LARGE_PROP = 1 << 5
FLOOR_LEVEL = 1 << 6
MEDIUM_PEDESTAL = 1 << 7
HIGH_PEDESTAL = 1 << 8
_ALL_CLIPS = None
def _get_clip_info(loader, clip_number, flags):
clip = loader.get_trajectory(IDENTIFIER_TEMPLATE.format(clip_number))
return ClipInfo(
clip_identifier=clip.identifier,
num_steps=clip.num_steps,
dt=clip.dt,
flags=flags)
def _get_all_clip_infos_if_necessary():
"""Creates the global _ALL_CLIPS list if it has not already been created."""
global _ALL_CLIPS
if _ALL_CLIPS is None:
loader = mocap_loader.HDF5TrajectoryLoader(
H5_PATH, trajectories.WarehouseTrajectory)
clip_numbers = (1, 2, 3, 4, 5, 6, 9, 10,
11, 12, 15, 16, 17, 18, 19, 20,
21, 22, 23, 24, 25, 26, 27, 28,
29, 30, 31, 32, 33, 34, 35, 36,
37, 38, 39, 40, 42, 43, 44, 45,
46, 47, 48, 49, 50, 51, 52, 53)
clip_infos = []
for i, clip_number in enumerate(clip_numbers):
flags = 0
if i in _FLOOR_LEVEL:
flags |= Flag.FLOOR_LEVEL
elif i in _MEDIUM_PEDESTAL:
flags |= Flag.MEDIUM_PEDESTAL
elif i in _HIGH_PEDESTAL:
flags |= Flag.HIGH_PEDESTAL
if i in _LIGHT_PROP:
flags |= Flag.LIGHT_PROP
elif i in _HEAVY_PROP:
flags |= Flag.HEAVY_PROP
if i in _SMALL_BOX:
flags |= Flag.SMALL_PROP
flags |= Flag.BOX
elif i in _LARGE_BOX:
flags |= Flag.LARGE_PROP
flags |= Flag.BOX
elif i in _SMALL_BALL:
flags |= Flag.SMALL_PROP
flags |= Flag.BALL
elif i in _LARGE_BALL:
flags |= Flag.LARGE_PROP
flags |= Flag.BALL
clip_infos.append(_get_clip_info(loader, clip_number, flags))
_ALL_CLIPS = tuple(clip_infos)
def _assert_partitions_all_clips(*args):
"""Asserts that a given set of subcollections partitions ALL_CLIPS."""
sets = tuple(set(arg) for arg in args)
# Check that the union of all the sets is ALL_CLIPS.
union = set()
for subset in sets:
union = union | set(subset)
assert union == set(range(48))
# Check that the sets are pairwise disjoint.
for i in range(len(sets)):
for j in range(i + 1, len(sets)):
assert sets[i] & sets[j] == set()
_FLOOR_LEVEL = tuple(range(0, 16))
_MEDIUM_PEDESTAL = tuple(range(16, 32))
_HIGH_PEDESTAL = tuple(range(32, 48))
_assert_partitions_all_clips(_FLOOR_LEVEL, _MEDIUM_PEDESTAL, _HIGH_PEDESTAL)
_LIGHT_PROP = (0, 1, 2, 3, 8, 9, 12, 13, 16, 17, 18, 19, 24,
25, 26, 27, 34, 35, 38, 39, 42, 43, 46, 47)
_HEAVY_PROP = (4, 5, 6, 7, 10, 11, 14, 15, 20, 21, 22, 23, 28,
29, 30, 31, 32, 33, 36, 37, 40, 41, 44, 45)
_assert_partitions_all_clips(_LIGHT_PROP, _HEAVY_PROP)
_SMALL_BOX = (0, 1, 4, 5, 16, 17, 20, 21, 34, 35, 36, 37)
_LARGE_BOX = (2, 3, 6, 7, 18, 19, 22, 23, 32, 33, 38, 39)
_SMALL_BALL = (8, 9, 10, 11, 24, 25, 30, 31, 40, 41, 46, 47)
_LARGE_BALL = (12, 13, 14, 15, 26, 27, 28, 29, 42, 43, 44, 45)
_assert_partitions_all_clips(_SMALL_BOX, _LARGE_BOX, _SMALL_BALL, _LARGE_BALL)
def all_clips():
_get_all_clip_infos_if_necessary()
return _ALL_CLIPS
def floor_level():
clips = all_clips()
return tuple(clips[i] for i in _FLOOR_LEVEL)
def medium_pedestal():
clips = all_clips()
return tuple(clips[i] for i in _MEDIUM_PEDESTAL)
def high_pedestal():
clips = all_clips()
return tuple(clips[i] for i in _HIGH_PEDESTAL)
def light_prop():
clips = all_clips()
return tuple(clips[i] for i in _LIGHT_PROP)
def heavy_prop():
clips = all_clips()
return tuple(clips[i] for i in _HEAVY_PROP)
def small_box():
clips = all_clips()
return tuple(clips[i] for i in _SMALL_BOX)
def large_box():
clips = all_clips()
return tuple(clips[i] for i in _LARGE_BOX)
def small_ball():
clips = all_clips()
return tuple(clips[i] for i in _SMALL_BALL)
def large_ball():
clips = all_clips()
return tuple(clips[i] for i in _LARGE_BALL)