diff --git a/README.md b/README.md
index 3e17a2f..3e8e804 100644
--- a/README.md
+++ b/README.md
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
+* [Catch & Carry: Reusable Neural Controllers for Vision-Guided Whole-Body Tasks](catch_carry), SIGGRAPH 2020
* [MEMO: A Deep Network For Flexible Combination Of Episodic Memories](memo), ICLR 2020
* [RL Unplugged: Benchmarks for Offline Reinforcement Learning](rl_unplugged)
* [Disentangling by Subspace Diffusion (GEOMANCER)](geomancer)
diff --git a/catch_carry/README.md b/catch_carry/README.md
new file mode 100644
index 0000000..8bad38b
--- /dev/null
+++ b/catch_carry/README.md
@@ -0,0 +1,123 @@
+# Catch & Carry: Reusable Neural Controllers for Vision-Guided Whole-Body Tasks
+
+This package contains motion capture data and tasks associated with "Catch &
+Carry: Reusable Neural Controllers for Vision-Guided Whole-Body Tasks"
+(https://arxiv.org/abs/1911.06636), which was published at SIGGRAPH 2020.
+This is research code that depends on the more stable code available as part
+of [`dm_control`], in particular on components in [`dm_control.locomotion`].
+
+Preconfigured environments for the "warehouse" and "ball toss" tasks are
+provided in `task_examples.py`. To load these environments in the MuJoCo
+interactive viewer from [`dm_control.viewer`], see `explore.py`.
+
+## Installation instructions
+
+1. Download [MuJoCo Pro](https://mujoco.org/) and extract the zip archive as
+ `~/.mujoco/mujoco200_$PLATFORM` where `$PLATFORM` is one of `linux`,
+ `macos`, or `win64`.
+
+2. Ensure that a valid MuJoCo license key file is located at
+ `~/.mujoco/mjkey.txt`.
+
+3. Clone the `deepmind-research` repository:
+
+ ```shell
+ git clone https://github.com/deepmind/deepmind-research.git
+ cd deepmind-research
+ ```
+
+4. Create and activate a Python virtual environment:
+
+ ```shell
+ python3 -m virtualenv catch_carry
+ source catch_carry/bin/activate
+ ```
+
+5. Install the package:
+
+ ```shell
+ pip install ./catch_carry
+ ```
+
+## Quickstart
+
+To instantiate and step through the warehouse task:
+
+```python
+from catch_carry import task_examples
+import numpy as np
+
+# Build an example environment.
+env = task_examples.build_vision_warehouse()
+
+# Get the `action_spec` describing the control inputs.
+action_spec = env.action_spec()
+
+# Step through the environment for one episode with random actions.
+time_step = env.reset()
+while not time_step.last():
+ action = np.random.uniform(action_spec.minimum, action_spec.maximum,
+ size=action_spec.shape)
+ time_step = env.step(action)
+ print("reward = {}, discount = {}, observations = {}.".format(
+ time_step.reward, time_step.discount, time_step.observation))
+```
+
+The above code snippet can also be used for the ball toss task by replacing
+`build_vision_warehouse` with `build_vision_toss`.
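+
+For example, this minimal sketch builds the ball toss environment and prints
+its observation spec:
+
+```python
+from catch_carry import task_examples
+
+env = task_examples.build_vision_toss()
+print(env.observation_spec())
+```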
+
+## Visualization
+
+[`dm_control.viewer`] can be used to visualize and interact with the
+environment; we provide the `explore.py` script for this purpose. If you
+followed the installation instructions above, the viewer can be launched for
+the warehouse task via:
+
+```shell
+python3 -m catch_carry.explore --task=warehouse
+```
+
+and for the ball toss task via:
+
+```shell
+python3 -m catch_carry.explore --task=toss
+```
+
+## Citation
+
+If you use the code or data in this package, please cite:
+
+```
+@article{merel2020catch,
+ title = {Catch \& Carry: Reusable Neural Controllers for
+ Vision-Guided Whole-Body Tasks},
+ author = {Merel, Josh and
+ Tunyasuvunakool, Saran and
+ Ahuja, Arun and
+ Tassa, Yuval and
+ Hasenclever, Leonard and
+ Pham, Vu and
+ Erez, Tom and
+ Wayne, Greg and
+ Heess, Nicolas},
+ journal = {ACM Trans. Graph.},
+ issue_date = {July 2020},
+ publisher = {Association for Computing Machinery},
+ address = {New York, NY, USA},
+ volume = {39},
+ number = {4},
+ numpages = {14},
+ issn = {0730-0301},
+ year = {2020},
+ month = jul,
+}
+```
+
+[`dm_control`]: https://github.com/deepmind/dm_control
+[`dm_control.locomotion`]: https://github.com/deepmind/dm_control/tree/master/dm_control/locomotion
+[`dm_control.viewer`]: https://github.com/deepmind/dm_control/tree/master/dm_control/viewer
diff --git a/catch_carry/__init__.py b/catch_carry/__init__.py
new file mode 100644
index 0000000..175aba9
--- /dev/null
+++ b/catch_carry/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2020 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/catch_carry/arm_opener.py b/catch_carry/arm_opener.py
new file mode 100644
index 0000000..dbac8a2
--- /dev/null
+++ b/catch_carry/arm_opener.py
@@ -0,0 +1,171 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utility for opening arms until they are not in contact with a prop."""
+
+import contextlib
+
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+_MAX_IK_ATTEMPTS = 100
+_IK_MAX_CORRECTION_WEIGHT = 0.1
+_JOINT_LIMIT_TOLERANCE = 1e-4
+_GAP_TOLERANCE = 0.1
+
+
+class _ArmPropContactRemover(object):
+ """Helper class for removing contacts between an arm and a prop via IK."""
+
+ def __init__(self, physics, arm_root, prop, gap):
+ arm_geoms = arm_root.find_all('geom')
+ self._arm_geom_ids = set(physics.bind(arm_geoms).element_id)
+ arm_joints = arm_root.find_all('joint')
+ self._arm_joint_ids = list(physics.bind(arm_joints).element_id)
+ self._arm_qpos_indices = physics.model.jnt_qposadr[self._arm_joint_ids]
+ self._arm_dof_indices = physics.model.jnt_dofadr[self._arm_joint_ids]
+
+ self._prop_geoms = prop.find_all('geom')
+ self._prop_geom_ids = set(physics.bind(self._prop_geoms).element_id)
+
+ self._arm_joint_min = np.full(len(self._arm_joint_ids), float('-inf'),
+ dtype=physics.model.jnt_range.dtype)
+ self._arm_joint_max = np.full(len(self._arm_joint_ids), float('inf'),
+ dtype=physics.model.jnt_range.dtype)
+ for i, joint_id in enumerate(self._arm_joint_ids):
+ if physics.model.jnt_limited[joint_id]:
+ self._arm_joint_min[i], self._arm_joint_max[i] = (
+ physics.model.jnt_range[joint_id])
+
+ self._gap = gap
+
+ def _contact_pair_is_relevant(self, contact):
+ set1 = self._arm_geom_ids
+ set2 = self._prop_geom_ids
+ return ((contact.geom1 in set1 and contact.geom2 in set2) or
+ (contact.geom2 in set1 and contact.geom1 in set2))
+
+ def _forward_and_find_next_contact(self, physics):
+ """Forwards the physics and finds the next contact to handle."""
+ physics.forward()
+ next_contact = None
+ for contact in physics.data.contact:
+ if (self._contact_pair_is_relevant(contact) and
+ (next_contact is None or contact.dist < next_contact.dist)):
+ next_contact = contact
+ return next_contact
+
+ def _remove_contact_ik_iteration(self, physics, contact):
+ """Performs one linearized IK iteration to remove the specified contact."""
+ if contact.geom1 in self._arm_geom_ids:
+ sign = -1
+ geom_id = contact.geom1
+ else:
+ sign = 1
+ geom_id = contact.geom2
+
+ body_id = physics.model.geom_bodyid[geom_id]
+ normal = sign * contact.frame[:3]
+
+ jac_dtype = physics.data.qpos.dtype
+ jac = np.empty((6, physics.model.nv), dtype=jac_dtype)
+ jac_pos, jac_rot = jac[:3], jac[3:]
+ mjbindings.mjlib.mj_jacPointAxis(
+ physics.model.ptr, physics.data.ptr,
+ jac_pos, jac_rot,
+ contact.pos + (contact.dist / 2) * normal, normal, body_id)
+
+ # Calculate corrections w.r.t. all joints, disregarding joint limits.
+ delta_xpos = normal * max(0, self._gap - contact.dist)
+ jac_all_joints = jac_pos[:, self._arm_dof_indices]
+ update_unfiltered = np.linalg.lstsq(
+ jac_all_joints, delta_xpos, rcond=None)[0]
+
+ # Filter out joints at limit that are corrected in the "wrong" direction.
+ initial_qpos = np.array(physics.data.qpos[self._arm_qpos_indices])
+ min_filter = np.logical_and(
+ initial_qpos - self._arm_joint_min < _JOINT_LIMIT_TOLERANCE,
+ update_unfiltered < 0)
+ max_filter = np.logical_and(
+ self._arm_joint_max - initial_qpos < _JOINT_LIMIT_TOLERANCE,
+ update_unfiltered > 0)
+ active_joints = np.where(
+ np.logical_not(np.logical_or(min_filter, max_filter)))[0]
+
+ # Calculate corrections w.r.t. valid joints only.
+ active_dof_indices = self._arm_dof_indices[active_joints]
+ jac_joints = jac_pos[:, active_dof_indices]
+ update_filtered = np.linalg.lstsq(jac_joints, delta_xpos, rcond=None)[0]
+ update_nv = np.zeros(physics.model.nv, dtype=jac_dtype)
+ update_nv[active_dof_indices] = update_filtered
+
+ # Calculate maximum correction weight that does not violate joint limits.
+ weights = np.full_like(update_filtered, _IK_MAX_CORRECTION_WEIGHT)
+ active_initial_qpos = initial_qpos[active_joints]
+ active_joint_min = self._arm_joint_min[active_joints]
+ active_joint_max = self._arm_joint_max[active_joints]
+ for i in range(len(weights)):
+ proposed_update = update_filtered[i]
+ if proposed_update > 0:
+ max_allowed_update = active_joint_max[i] - active_initial_qpos[i]
+ weights[i] = min(max_allowed_update / proposed_update, weights[i])
+ elif proposed_update < 0:
+ min_allowed_update = active_joint_min[i] - active_initial_qpos[i]
+ weights[i] = min(min_allowed_update / proposed_update, weights[i])
+ weight = min(weights)
+
+ # Integrate the correction into `qpos`.
+ mjbindings.mjlib.mj_integratePos(
+ physics.model.ptr, physics.data.qpos, update_nv, weight)
+
+ # "Paranoid" clip the modified joint `qpos` to within joint limits.
+ active_qpos_indices = self._arm_qpos_indices[active_joints]
+ physics.data.qpos[active_qpos_indices] = np.clip(
+ physics.data.qpos[active_qpos_indices],
+ active_joint_min, active_joint_max)
+
+ @contextlib.contextmanager
+ def _override_margins_and_gaps(self, physics):
+ """Context manager that overrides geom margins and gaps to `self._gap`."""
+ prop_geom_bindings = physics.bind(self._prop_geoms)
+ original_margins = np.array(prop_geom_bindings.margin)
+ original_gaps = np.array(prop_geom_bindings.gap)
+ prop_geom_bindings.margin = self._gap * (1 - _GAP_TOLERANCE)
+ prop_geom_bindings.gap = self._gap * (1 - _GAP_TOLERANCE)
+ yield
+ prop_geom_bindings.margin = original_margins
+ prop_geom_bindings.gap = original_gaps
+ physics.forward()
+
+ def remove_contacts(self, physics):
+ with self._override_margins_and_gaps(physics):
+ for _ in range(_MAX_IK_ATTEMPTS):
+ contact = self._forward_and_find_next_contact(physics)
+ if contact is None:
+ return
+ self._remove_contact_ik_iteration(physics, contact)
+ contact = self._forward_and_find_next_contact(physics)
+ if contact and contact.dist < 0:
+ raise RuntimeError(
+ 'Failed to remove contact with prop after {} iterations. '
+ 'Final contact distance is {}.'.format(
+ _MAX_IK_ATTEMPTS, contact.dist))
+
+
+def open_arms_for_prop(physics, left_arm_root, right_arm_root, prop, gap):
+ """Opens left and right arms so as to leave a specified gap with the prop."""
+ left_arm_opener = _ArmPropContactRemover(physics, left_arm_root, prop, gap)
+ left_arm_opener.remove_contacts(physics)
+ right_arm_opener = _ArmPropContactRemover(physics, right_arm_root, prop, gap)
+ right_arm_opener.remove_contacts(physics)
diff --git a/catch_carry/ball_toss.py b/catch_carry/ball_toss.py
new file mode 100644
index 0000000..f0c7799
--- /dev/null
+++ b/catch_carry/ball_toss.py
@@ -0,0 +1,319 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A ball-tossing task."""
+
+import collections
+
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer import variation
+from dm_control.composer.observation import observable
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.mocap import loader as mocap_loader
+import numpy as np
+
+from catch_carry import mocap_data
+from catch_carry import props
+from catch_carry import trajectories
+
+_PHYSICS_TIMESTEP = 0.005
+
+_BUCKET_SIZE = (0.2, 0.2, 0.02)
+
+# Magnitude of the sparse reward.
+_SPARSE_REWARD = 1.0
+
+
+class BallToss(composer.Task):
+ """A task involving catching and throwing a ball."""
+
+ def __init__(self, walker,
+ proto_modifier=None,
+ negative_reward_on_failure_termination=True,
+ priority_friction=False,
+ bucket_offset=1.,
+ y_range=0.5,
+ toss_delay=0.5,
+ randomize_init=False,
+ ):
+ """Initialize ball tossing task.
+
+ Args:
+ walker: the walker to be used in this task.
+ proto_modifier: function to modify trajectory proto.
+      negative_reward_on_failure_termination: flag to provide a negative
+        reward when the task fails.
+      priority_friction: sets friction priority, giving the prop objects higher
+        friction.
+      bucket_offset: distance in meters to push the bucket away from the
+        walker.
+      y_range: range (uniformly sampled) of distance in meters the ball is
+        thrown left/right of the walker.
+      toss_delay: time in seconds to wait after the catch before changing the
+        reward to encourage throwing the ball.
+      randomize_init: flag to randomize the initial pose.
+ """
+ self._proto_modifier = proto_modifier
+ self._negative_reward_on_failure_termination = (
+ negative_reward_on_failure_termination)
+ self._priority_friction = priority_friction
+ self._bucket_rewarded = False
+ self._bucket_offset = bucket_offset
+ self._y_range = y_range
+ self._toss_delay = toss_delay
+ self._randomize_init = randomize_init
+
+    # Load a mocap clip that provides the ball prop and initialization poses.
+ loader = mocap_loader.HDF5TrajectoryLoader(
+ mocap_data.H5_PATH, trajectories.WarehouseTrajectory)
+ clip_number = 54
+ self._trajectory = loader.get_trajectory(
+ mocap_data.IDENTIFIER_TEMPLATE.format(clip_number))
+
+ # create the floor arena
+ self._arena = floors.Floor()
+
+ self._walker = walker
+ self._walker_geoms = tuple(self._walker.mjcf_model.find_all('geom'))
+ self._feet_geoms = (
+ walker.mjcf_model.find('body', 'lfoot').find_all('geom') +
+ walker.mjcf_model.find('body', 'rfoot').find_all('geom'))
+ self._lhand_geoms = (
+ walker.mjcf_model.find('body', 'lhand').find_all('geom'))
+ self._rhand_geoms = (
+ walker.mjcf_model.find('body', 'rhand').find_all('geom'))
+
+ # resize the humanoid based on the motion capture data subject
+ self._trajectory.configure_walkers([self._walker])
+ walker.create_root_joints(self._arena.attach(walker))
+
+ control_timestep = self._trajectory.dt
+ self.set_timesteps(control_timestep, _PHYSICS_TIMESTEP)
+
+ # build and attach the bucket to the arena
+ self._bucket = props.Bucket(_BUCKET_SIZE)
+ self._arena.attach(self._bucket)
+
+ self._prop = self._trajectory.create_props(
+ priority_friction=self._priority_friction)[0]
+ self._arena.add_free_entity(self._prop)
+
+ self._task_observables = collections.OrderedDict()
+
+ # define feature based observations (agent may or may not use these)
+ def ego_prop_xpos(physics):
+ prop_xpos, _ = self._prop.get_pose(physics)
+ walker_xpos = physics.bind(self._walker.root_body).xpos
+ return self._walker.transform_vec_to_egocentric_frame(
+ physics, prop_xpos - walker_xpos)
+ self._task_observables['prop_{}/xpos'.format(0)] = (
+ observable.Generic(ego_prop_xpos))
+
+ def prop_zaxis(physics):
+ prop_xmat = physics.bind(
+ mjcf.get_attachment_frame(self._prop.mjcf_model)).xmat
+ return prop_xmat[[2, 5, 8]]
+ self._task_observables['prop_{}/zaxis'.format(0)] = (
+ observable.Generic(prop_zaxis))
+
+ def ego_bucket_xpos(physics):
+ bucket_xpos, _ = self._bucket.get_pose(physics)
+ walker_xpos = physics.bind(self._walker.root_body).xpos
+ return self._walker.transform_vec_to_egocentric_frame(
+ physics, bucket_xpos - walker_xpos)
+ self._task_observables['bucket_{}/xpos'.format(0)] = (
+ observable.Generic(ego_bucket_xpos))
+
+ for obs in (self._walker.observables.proprioception +
+ self._walker.observables.kinematic_sensors +
+ self._walker.observables.dynamic_sensors +
+ list(self._task_observables.values())):
+ obs.enabled = True
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def name(self):
+ return 'ball_toss'
+
+ def initialize_episode_mjcf(self, random_state):
+ self._reward = 0.0
+ self._discount = 1.0
+ self._should_terminate = False
+
+ self._prop.detach()
+
+    if self._proto_modifier:
+      trajectory = self._trajectory.get_modified_trajectory(
+          self._proto_modifier)
+    else:
+      trajectory = self._trajectory
+
+ self._prop = trajectory.create_props(
+ priority_friction=self._priority_friction)[0]
+ self._arena.add_free_entity(self._prop)
+
+ # set the bucket position for this episode
+ bucket_distance = 1.*random_state.rand()+self._bucket_offset
+ mjcf.get_attachment_frame(self._bucket.mjcf_model).pos = [bucket_distance,
+ 0, 0]
+
+ def initialize_episode(self, physics, random_state):
+ self._ground_geomid = physics.bind(
+ self._arena.mjcf_model.worldbody.geom[0]).element_id
+ self._feet_geomids = set(physics.bind(self._feet_geoms).element_id)
+ self._lhand_geomids = set(physics.bind(self._lhand_geoms).element_id)
+ self._rhand_geomids = set(physics.bind(self._rhand_geoms).element_id)
+ self._walker_geomids = set(physics.bind(self._walker_geoms).element_id)
+ self._bucket_rewarded = False
+
+ if self._randomize_init:
+ timestep_ind = random_state.randint(
+ len(self._trajectory._proto.timesteps)) # pylint: disable=protected-access
+ else:
+ timestep_ind = 0
+ walker_init_timestep = self._trajectory._proto.timesteps[timestep_ind] # pylint: disable=protected-access
+ prop_init_timestep = self._trajectory._proto.timesteps[0] # pylint: disable=protected-access
+
+ self._walker.set_pose(
+ physics,
+ position=walker_init_timestep.walkers[0].position,
+ quaternion=walker_init_timestep.walkers[0].quaternion)
+ self._walker.set_velocity(
+ physics, velocity=walker_init_timestep.walkers[0].velocity,
+ angular_velocity=walker_init_timestep.walkers[0].angular_velocity)
+ physics.bind(self._walker.mocap_joints).qpos = (
+ walker_init_timestep.walkers[0].joints)
+ physics.bind(self._walker.mocap_joints).qvel = (
+ walker_init_timestep.walkers[0].joints_velocity)
+
+ initial_prop_pos = np.copy(prop_init_timestep.props[0].position)
+ initial_prop_pos[0] += 1. # move ball (from mocap) relative to origin
+ initial_prop_pos[1] = 0 # align ball with walker along y-axis
+ self._prop.set_pose(
+ physics,
+ position=initial_prop_pos,
+ quaternion=prop_init_timestep.props[0].quaternion)
+
+ # specify the distributions of ball velocity componentwise
+ x_vel_mag = 4.5*random_state.rand()+1.5 # m/s
+ x_dist = 3 # approximate initial distance from walker to ball
+ self._t_dist = x_dist/x_vel_mag # target time at which to hit the humanoid
+ z_offset = .4*random_state.rand()+.1 # height at which to hit person
+ # compute velocity to satisfy desired projectile trajectory
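+    # Solving z_offset = v_z*t - 0.5*g*t^2 (with g = 9.8 m/s^2) for the
+    # vertical launch speed v_z at target time t = self._t_dist gives: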
+ z_vel_mag = (4.9*(self._t_dist**2) + z_offset)/self._t_dist
+
+ y_range = variation.evaluate(self._y_range, random_state=random_state)
+ y_vel_mag = y_range*random_state.rand()-y_range/2
+ trans_vel = [-x_vel_mag, y_vel_mag, z_vel_mag]
+ ang_vel = 1.5*random_state.rand(3)-0.75
+ self._prop.set_velocity(
+ physics,
+ velocity=trans_vel,
+ angular_velocity=ang_vel)
+
+ def after_step(self, physics, random_state):
+ # First we check for failure termination (walker or ball touches ground).
+ ground_failure = False
+ for contact in physics.data.contact:
+ if ((contact.geom1 == self._ground_geomid and
+ contact.geom2 not in self._feet_geomids) or
+ (contact.geom2 == self._ground_geomid and
+ contact.geom1 not in self._feet_geomids)):
+ ground_failure = True
+ break
+
+ contact_features = self._evaluate_contacts(physics)
+ prop_lhand, prop_rhand, bucket_prop, bucket_walker, walker_prop = contact_features
+
+    # The episode also fails if the walker touches the bucket.
+ if ground_failure or bucket_walker:
+ if self._negative_reward_on_failure_termination:
+ self._reward = -_SPARSE_REWARD
+ else:
+ self._reward = 0.0
+ self._should_terminate = True
+ self._discount = 0.0
+ return
+
+ self._reward = 0.0
+ # give reward if prop is in bucket (prop touching bottom surface of bucket)
+ if bucket_prop:
+ self._reward += _SPARSE_REWARD/10
+
+ # shaping reward for being closer to bucket
+ if physics.data.time > (self._t_dist + self._toss_delay):
+ bucket_xy = physics.bind(self._bucket.geom).xpos[0][:2]
+ prop_xy = self._prop.get_pose(physics)[0][:2]
+ xy_dist = np.sum(np.array(np.abs(bucket_xy - prop_xy)))
+ self._reward += np.exp(-xy_dist/3.)*_SPARSE_REWARD/50
+ else:
+ # bonus for hands touching ball
+ if prop_lhand:
+ self._reward += _SPARSE_REWARD/100
+ if prop_rhand:
+ self._reward += _SPARSE_REWARD/100
+ # combined with penalty for other body parts touching the ball
+ if walker_prop:
+ self._reward -= _SPARSE_REWARD/100
+
+ def get_reward(self, physics):
+ return self._reward
+
+ def get_discount(self, physics):
+ return self._discount
+
+ def should_terminate_episode(self, physics):
+ return self._should_terminate
+
+ def _evaluate_contacts(self, physics):
+ prop_elem_id = physics.bind(self._prop.geom).element_id
+ bucket_bottom_elem_id = physics.bind(self._bucket.geom[0]).element_id
+ bucket_any_elem_id = set(physics.bind(self._bucket.geom).element_id)
+ prop_lhand_contact = False
+ prop_rhand_contact = False
+ bucket_prop_contact = False
+ bucket_walker_contact = False
+ walker_prop_contact = False
+
+ for contact in physics.data.contact:
+ has_prop = (contact.geom1 == prop_elem_id or
+ contact.geom2 == prop_elem_id)
+ has_bucket_bottom = (contact.geom1 == bucket_bottom_elem_id or
+ contact.geom2 == bucket_bottom_elem_id)
+ has_bucket_any = (contact.geom1 in bucket_any_elem_id or
+ contact.geom2 in bucket_any_elem_id)
+ has_lhand = (contact.geom1 in self._lhand_geomids or
+ contact.geom2 in self._lhand_geomids)
+ has_rhand = (contact.geom1 in self._rhand_geomids or
+ contact.geom2 in self._rhand_geomids)
+ has_walker = (contact.geom1 in self._walker_geomids or
+ contact.geom2 in self._walker_geomids)
+ if has_prop and has_bucket_bottom:
+ bucket_prop_contact = True
+ if has_walker and has_bucket_any:
+ bucket_walker_contact = True
+ if has_walker and has_prop:
+ walker_prop_contact = True
+ if has_prop and has_lhand:
+ prop_lhand_contact = True
+ if has_prop and has_rhand:
+ prop_rhand_contact = True
+
+ return (prop_lhand_contact, prop_rhand_contact, bucket_prop_contact,
+ bucket_walker_contact, walker_prop_contact)
diff --git a/catch_carry/explore.py b/catch_carry/explore.py
new file mode 100644
index 0000000..adcdd9c
--- /dev/null
+++ b/catch_carry/explore.py
@@ -0,0 +1,37 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Simple script to launch viewer with an example environment."""
+
+from absl import app
+from absl import flags
+from dm_control import viewer
+from catch_carry import task_examples
+
+FLAGS = flags.FLAGS
+flags.DEFINE_enum('task', 'warehouse', ['warehouse', 'toss'],
+ 'The task to visualize.')
+
+TASKS = {
+ 'warehouse': task_examples.build_vision_warehouse,
+ 'toss': task_examples.build_vision_toss,
+}
+
+
+def main(unused_argv):
+ viewer.launch(environment_loader=TASKS[FLAGS.task])
+
+if __name__ == '__main__':
+ app.run(main)
+
diff --git a/catch_carry/mocap_data.h5 b/catch_carry/mocap_data.h5
new file mode 100644
index 0000000..ec8c3a8
Binary files /dev/null and b/catch_carry/mocap_data.h5 differ
diff --git a/catch_carry/mocap_data.py b/catch_carry/mocap_data.py
new file mode 100644
index 0000000..1bcdef1
--- /dev/null
+++ b/catch_carry/mocap_data.py
@@ -0,0 +1,187 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Metadata for mocap clips that correspond to a walker carrying a prop."""
+
+import collections
+import enum
+import os
+
+from dm_control.locomotion.mocap import loader as mocap_loader
+
+
+from catch_carry import trajectories
+
+H5_DIR = os.path.dirname(__file__)
+H5_PATH = os.path.join(H5_DIR, 'mocap_data.h5')
+
+IDENTIFIER_PREFIX = 'DeepMindCatchCarry'
+IDENTIFIER_TEMPLATE = IDENTIFIER_PREFIX + '-{:03d}'
+
+ClipInfo = collections.namedtuple(
+ 'ClipInfo', ('clip_identifier', 'num_steps', 'dt', 'flags'))
+
+
+class Flag(enum.IntEnum):
+ BOX = 1 << 0
+ BALL = 1 << 1
+ LIGHT_PROP = 1 << 2
+ HEAVY_PROP = 1 << 3
+ SMALL_PROP = 1 << 4
+ LARGE_PROP = 1 << 5
+ FLOOR_LEVEL = 1 << 6
+ MEDIUM_PEDESTAL = 1 << 7
+ HIGH_PEDESTAL = 1 << 8
+
+
+_ALL_CLIPS = None
+
+
+def _get_clip_info(loader, clip_number, flags):
+ clip = loader.get_trajectory(IDENTIFIER_TEMPLATE.format(clip_number))
+ return ClipInfo(
+ clip_identifier=clip.identifier,
+ num_steps=clip.num_steps,
+ dt=clip.dt,
+ flags=flags)
+
+
+def _get_all_clip_infos_if_necessary():
+ """Creates the global _ALL_CLIPS list if it has not already been created."""
+ global _ALL_CLIPS
+ if _ALL_CLIPS is None:
+ loader = mocap_loader.HDF5TrajectoryLoader(
+ H5_PATH, trajectories.WarehouseTrajectory)
+ clip_numbers = (1, 2, 3, 4, 5, 6, 9, 10,
+ 11, 12, 15, 16, 17, 18, 19, 20,
+ 21, 22, 23, 24, 25, 26, 27, 28,
+ 29, 30, 31, 32, 33, 34, 35, 36,
+ 37, 38, 39, 40, 42, 43, 44, 45,
+ 46, 47, 48, 49, 50, 51, 52, 53)
+
+ clip_infos = []
+ for i, clip_number in enumerate(clip_numbers):
+ flags = 0
+
+ if i in _FLOOR_LEVEL:
+ flags |= Flag.FLOOR_LEVEL
+ elif i in _MEDIUM_PEDESTAL:
+ flags |= Flag.MEDIUM_PEDESTAL
+ elif i in _HIGH_PEDESTAL:
+ flags |= Flag.HIGH_PEDESTAL
+
+ if i in _LIGHT_PROP:
+ flags |= Flag.LIGHT_PROP
+ elif i in _HEAVY_PROP:
+ flags |= Flag.HEAVY_PROP
+
+ if i in _SMALL_BOX:
+ flags |= Flag.SMALL_PROP
+ flags |= Flag.BOX
+ elif i in _LARGE_BOX:
+ flags |= Flag.LARGE_PROP
+ flags |= Flag.BOX
+ elif i in _SMALL_BALL:
+ flags |= Flag.SMALL_PROP
+ flags |= Flag.BALL
+ elif i in _LARGE_BALL:
+ flags |= Flag.LARGE_PROP
+ flags |= Flag.BALL
+ clip_infos.append(_get_clip_info(loader, clip_number, flags))
+
+ _ALL_CLIPS = tuple(clip_infos)
+
+
+def _assert_partitions_all_clips(*args):
+ """Asserts that a given set of subcollections partitions ALL_CLIPS."""
+ sets = tuple(set(arg) for arg in args)
+
+ # Check that the union of all the sets is ALL_CLIPS.
+ union = set()
+ for subset in sets:
+ union = union | set(subset)
+ assert union == set(range(48))
+
+ # Check that the sets are pairwise disjoint.
+ for i in range(len(sets)):
+ for j in range(i + 1, len(sets)):
+ assert sets[i] & sets[j] == set()
+
+
+_FLOOR_LEVEL = tuple(range(0, 16))
+_MEDIUM_PEDESTAL = tuple(range(16, 32))
+_HIGH_PEDESTAL = tuple(range(32, 48))
+_assert_partitions_all_clips(_FLOOR_LEVEL, _MEDIUM_PEDESTAL, _HIGH_PEDESTAL)
+
+_LIGHT_PROP = (0, 1, 2, 3, 8, 9, 12, 13, 16, 17, 18, 19, 24,
+ 25, 26, 27, 34, 35, 38, 39, 42, 43, 46, 47)
+_HEAVY_PROP = (4, 5, 6, 7, 10, 11, 14, 15, 20, 21, 22, 23, 28,
+ 29, 30, 31, 32, 33, 36, 37, 40, 41, 44, 45)
+_assert_partitions_all_clips(_LIGHT_PROP, _HEAVY_PROP)
+
+_SMALL_BOX = (0, 1, 4, 5, 16, 17, 20, 21, 34, 35, 36, 37)
+_LARGE_BOX = (2, 3, 6, 7, 18, 19, 22, 23, 32, 33, 38, 39)
+_SMALL_BALL = (8, 9, 10, 11, 24, 25, 30, 31, 40, 41, 46, 47)
+_LARGE_BALL = (12, 13, 14, 15, 26, 27, 28, 29, 42, 43, 44, 45)
+_assert_partitions_all_clips(_SMALL_BOX, _LARGE_BOX, _SMALL_BALL, _LARGE_BALL)
+
+
+def all_clips():
+ _get_all_clip_infos_if_necessary()
+ return _ALL_CLIPS
+
+
+def floor_level():
+ clips = all_clips()
+ return tuple(clips[i] for i in _FLOOR_LEVEL)
+
+
+def medium_pedestal():
+ clips = all_clips()
+ return tuple(clips[i] for i in _MEDIUM_PEDESTAL)
+
+
+def high_pedestal():
+ clips = all_clips()
+ return tuple(clips[i] for i in _HIGH_PEDESTAL)
+
+
+def light_prop():
+ clips = all_clips()
+ return tuple(clips[i] for i in _LIGHT_PROP)
+
+
+def heavy_prop():
+ clips = all_clips()
+ return tuple(clips[i] for i in _HEAVY_PROP)
+
+
+def small_box():
+ clips = all_clips()
+ return tuple(clips[i] for i in _SMALL_BOX)
+
+
+def large_box():
+ clips = all_clips()
+ return tuple(clips[i] for i in _LARGE_BOX)
+
+
+def small_ball():
+ clips = all_clips()
+ return tuple(clips[i] for i in _SMALL_BALL)
+
+
+def large_ball():
+ clips = all_clips()
+ return tuple(clips[i] for i in _LARGE_BALL)
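+
+
+# Example (a sketch): subsets can be combined with set operations; e.g.
+# warehouse.py selects its clips via
+# `set(medium_pedestal()) & (set(small_box()) | set(large_box()))`.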
diff --git a/catch_carry/props.py b/catch_carry/props.py
new file mode 100644
index 0000000..9c2c40d
--- /dev/null
+++ b/catch_carry/props.py
@@ -0,0 +1,86 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A rectangular pedestal."""
+
+from dm_control import composer
+from dm_control import mjcf
+
+
+class Pedestal(composer.Entity):
+ """A rectangular pedestal."""
+
+ def _build(self, size=(.2, .3, .05), rgba=(0, .5, 0, 1), name='pedestal'):
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._geom = self._mjcf_root.worldbody.add(
+ 'geom', type='box', size=size, name='geom', rgba=rgba)
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def geom(self):
+ return self._geom
+
+ def after_compile(self, physics, unused_random_state):
+ super(Pedestal, self).after_compile(physics, unused_random_state)
+ self._body_geom_ids = set(
+ physics.bind(geom).element_id
+ for geom in self.mjcf_model.find_all('geom'))
+
+ @property
+ def body_geom_ids(self):
+ return self._body_geom_ids
+
+
+class Bucket(composer.Entity):
+ """A rectangular bucket."""
+
+ def _build(self, size=(.2, .3, .05), rgba=(0, .5, 0, 1), name='pedestal'):
+ self._mjcf_root = mjcf.RootElement(model=name)
+ self._geoms = []
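+    # A bottom plate plus four thin side walls, forming an open-topped box.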
+ self._geoms.append(self._mjcf_root.worldbody.add(
+ 'geom', type='box', size=size, name='geom_bottom', rgba=rgba))
+ self._geoms.append(self._mjcf_root.worldbody.add(
+ 'geom', type='box', size=(size[2], size[1], size[0]), name='geom_s1',
+ rgba=rgba, pos=[size[0], 0, size[0]]))
+ self._geoms.append(self._mjcf_root.worldbody.add(
+ 'geom', type='box', size=(size[2], size[1], size[0]), name='geom_s2',
+ rgba=rgba, pos=[-size[0], 0, size[0]]))
+ self._geoms.append(self._mjcf_root.worldbody.add(
+ 'geom', type='box', size=(size[0], size[2], size[0]), name='geom_s3',
+ rgba=rgba, pos=[0, size[1], size[0]]))
+ self._geoms.append(self._mjcf_root.worldbody.add(
+ 'geom', type='box', size=(size[0], size[2], size[0]), name='geom_s4',
+ rgba=rgba, pos=[0, -size[1], size[0]]))
+
+ @property
+ def mjcf_model(self):
+ return self._mjcf_root
+
+ @property
+ def geom(self):
+ return self._geoms
+
+ def after_compile(self, physics, unused_random_state):
+ super(Bucket, self).after_compile(physics, unused_random_state)
+ self._body_geom_ids = set(
+ physics.bind(geom).element_id
+ for geom in self.mjcf_model.find_all('geom'))
+
+ @property
+ def body_geom_ids(self):
+ return self._body_geom_ids
+
diff --git a/catch_carry/setup.py b/catch_carry/setup.py
new file mode 100644
index 0000000..541750d
--- /dev/null
+++ b/catch_carry/setup.py
@@ -0,0 +1,35 @@
+# Copyright 2020 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Setup for pip package."""
+
+from setuptools import find_packages
+from setuptools import setup
+
+
+REQUIRED_PACKAGES = ['absl-py', 'dm_control', 'numpy']
+
+setup(
+ name='catch_carry',
+ version='0.1',
+ description='Whole-body object manipulation tasks and motion capture data.',
+    url='https://github.com/deepmind/deepmind-research/tree/master/catch_carry',
+ author='DeepMind',
+ author_email='stunya@google.com',
+ # Contained modules and scripts.
+ packages=find_packages(),
+ install_requires=REQUIRED_PACKAGES,
+ platforms=['any'],
+ license='Apache 2.0',
+)
diff --git a/catch_carry/task_examples.py b/catch_carry/task_examples.py
new file mode 100644
index 0000000..a3f81d4
--- /dev/null
+++ b/catch_carry/task_examples.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Functions that build representative tasks."""
+
+from dm_control import composer
+from dm_control.composer.variation import distributions
+from dm_control.locomotion.mocap import loader as mocap_loader
+from dm_control.locomotion.walkers import cmu_humanoid
+
+from catch_carry import ball_toss
+from catch_carry import warehouse
+
+
+def build_vision_warehouse(random_state=None):
+ """Build canonical 4-pedestal, 2-prop task."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build the task.
+ size_distribution = distributions.Uniform(low=0.75, high=1.25)
+ mass_distribution = distributions.Uniform(low=2, high=7)
+ prop_resizer = mocap_loader.PropResizer(size_factor=size_distribution,
+ mass=mass_distribution)
+ task = warehouse.PhasedBoxCarry(
+ walker=walker,
+ num_props=2,
+ num_pedestals=4,
+ proto_modifier=prop_resizer,
+ negative_reward_on_failure_termination=True)
+
+ # return the environment
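+  # (`max_reset_attempts=float('inf')` retries episode initialization until it
+  # succeeds.)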
+ return composer.Environment(
+ time_limit=15,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True,
+ max_reset_attempts=float('inf'))
+
+
+def build_vision_toss(random_state=None):
+ """Build canonical ball tossing task."""
+
+ # Build a position-controlled CMU humanoid walker.
+ walker = cmu_humanoid.CMUHumanoidPositionControlled(
+ observable_options={'egocentric_camera': dict(enabled=True)})
+
+ # Build the task.
+ size_distribution = distributions.Uniform(low=0.95, high=1.5)
+ mass_distribution = distributions.Uniform(low=2, high=4)
+ prop_resizer = mocap_loader.PropResizer(size_factor=size_distribution,
+ mass=mass_distribution)
+ task = ball_toss.BallToss(
+ walker=walker,
+ proto_modifier=prop_resizer,
+ negative_reward_on_failure_termination=True,
+ priority_friction=True,
+ bucket_offset=3.,
+ y_range=0.5,
+ toss_delay=1.5,
+ randomize_init=True)
+
+ # return the environment
+ return composer.Environment(
+ time_limit=6,
+ task=task,
+ random_state=random_state,
+ strip_singleton_obs_buffer_dim=True,
+ max_reset_attempts=float('inf'))
diff --git a/catch_carry/tasks.png b/catch_carry/tasks.png
new file mode 100644
index 0000000..337bda8
Binary files /dev/null and b/catch_carry/tasks.png differ
diff --git a/catch_carry/trajectories.py b/catch_carry/trajectories.py
new file mode 100644
index 0000000..547d708
--- /dev/null
+++ b/catch_carry/trajectories.py
@@ -0,0 +1,225 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Mocap trajectory that assumes props start stationary on pedestals."""
+
+import copy
+import enum
+import itertools
+
+from dm_control.locomotion.mocap import mocap_pb2
+from dm_control.locomotion.mocap import trajectory
+from dm_control.utils import transformations
+import numpy as np
+
+_PEDESTAL_SIZE = (0.2, 0.2, 0.02)
+_MAX_SETTLE_STEPS = 100
+
+
+@enum.unique
+class ClipSegment(enum.Enum):
+ """Annotations for subsegments within a warehouse clips."""
+
+ # Clip segment corresponding to a walker approaching an object
+ APPROACH = 1
+
+ # Clip segment corresponding to a walker picking up an object.
+ PICKUP = 2
+
+  # Clip segment corresponding to the "first half" of the walker carrying an
+  # object, beginning with the walker backing away from a pedestal with the
+  # object in hand.
+  CARRY1 = 3
+
+  # Clip segment corresponding to the "second half" of the walker carrying an
+  # object, ending with the walker approaching a pedestal with the object in
+  # hand.
+  CARRY2 = 4
+
+ # Clip segment corresponding to a walker putting down an object on a pedestal.
+ PUTDOWN = 5
+
+ # Clip segment corresponding to a walker backing off after successfully
+ # placing an object on a pedestal.
+ BACKOFF = 6
+
+
+def _get_rotated_bounding_box(size, quaternion):
+ """Calculates the bounding box of a rotated 3D box.
+
+ Args:
+ size: An array of length 3 specifying the half-lengths of a box.
+ quaternion: A unit quaternion specifying the box's orientation.
+
+ Returns:
+ An array of length 3 specifying the half-lengths of the bounding box of
+ the rotated box.
+ """
+ corners = ((size[0], size[1], size[2]),
+ (size[0], size[1], -size[2]),
+ (size[0], -size[1], size[2]),
+ (-size[0], size[1], size[2]))
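+  # The remaining four corners are the negations of the ones above; since the
+  # rotation is linear and we take absolute values below, they are covered too.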
+ rotated_corners = tuple(
+ transformations.quat_rotate(quaternion, corner) for corner in corners)
+ return np.amax(np.abs(rotated_corners), axis=0)
+
+
+def _get_prop_z_extent(prop_proto, quaternion):
+ """Calculates the "z-extent" of the prop in given orientation.
+
+ This is the distance from the centre of the prop to its lowest point in the
+ world frame, taking into account the prop's orientation.
+
+ Args:
+ prop_proto: A `mocap_pb2.Prop` protocol buffer defining a prop.
+ quaternion: A unit quaternion specifying the prop's orientation.
+
+ Returns:
+ the distance from the centre of the prop to its lowest point in the
+ world frame in the specified orientation.
+ """
+ if prop_proto.shape == mocap_pb2.Prop.BOX:
+ return _get_rotated_bounding_box(prop_proto.size, quaternion)[2]
+ elif prop_proto.shape == mocap_pb2.Prop.SPHERE:
+ return prop_proto.size[0]
+ else:
+ raise NotImplementedError(
+ 'Unsupported prop shape: {}'.format(prop_proto.shape))
+
+
+class WarehouseTrajectory(trajectory.Trajectory):
+ """Mocap trajectory that assumes props start stationary on pedestals."""
+
+ def infer_pedestal_positions(self, num_averaged_steps=30,
+ ground_height_tolerance=0.1,
+ proto_modifier=None):
+ proto = self._proto
+ if proto_modifier is not None:
+ proto = copy.copy(proto)
+ proto_modifier(proto)
+
+ if not proto.props:
+ return []
+
+ positions = []
+ for timestep in itertools.islice(proto.timesteps, num_averaged_steps):
+ positions_for_timestep = []
+ for prop_proto, prop_timestep in zip(proto.props, timestep.props):
+ z_extent = _get_prop_z_extent(prop_proto, prop_timestep.quaternion)
+ positions_for_timestep.append([prop_timestep.position[0],
+ prop_timestep.position[1],
+ prop_timestep.position[2] - z_extent])
+ positions.append(positions_for_timestep)
+
+ median_positions = np.median(positions, axis=0)
+ median_positions[:, 2][median_positions[:, 2] < ground_height_tolerance] = 0
+ return median_positions
+
+ def get_props_z_extent(self, physics):
+ timestep = self._proto.timesteps[self._get_step_id(physics.time())]
+ out = []
+ for prop_proto, prop_timestep in zip(self._proto.props, timestep.props):
+ z_extent = _get_prop_z_extent(prop_proto, prop_timestep.quaternion)
+ out.append(z_extent)
+ return out
+
+
+class SinglePropCarrySegmentedTrajectory(WarehouseTrajectory):
+ """A mocap trajectory class that automatically segments prop-carry clips.
+
+ The algorithm implemented in the class only works if the trajectory consists
+  of exactly one walker and one prop. The value of `pedestal_zone_distance` and
+  the exact nature of zone crossings were determined empirically from the
+  DeepMindCatchCarry dataset, and are unlikely to work well outside of this
+ setting.
+ """
+
+ def __init__(self,
+ proto,
+ start_time=None,
+ end_time=None,
+ pedestal_zone_distance=0.65,
+ start_step=None,
+ end_step=None,
+ zero_out_velocities=True):
+ super(SinglePropCarrySegmentedTrajectory, self).__init__(
+ proto, start_time, end_time, start_step=start_step, end_step=end_step,
+ zero_out_velocities=zero_out_velocities)
+ self._pedestal_zone_distance = pedestal_zone_distance
+ self._generate_segments()
+
+ def _generate_segments(self):
+ pedestal_position = self.infer_pedestal_positions()[0]
+
+    # First we find the timesteps at which the walker crosses the pedestal's
+ # vicinity zone. This should happen exactly 4 times: enter it to pick up,
+ # leave it, enter it again to put down, and leave it again.
+ was_in_pedestal_zone = False
+ crossings = []
+ for i, timestep in enumerate(self._proto.timesteps):
+ pedestal_dist = np.linalg.norm(
+ timestep.walkers[0].position[:2] - pedestal_position[:2])
+ if pedestal_dist > self._pedestal_zone_distance and was_in_pedestal_zone:
+ crossings.append(i)
+ was_in_pedestal_zone = False
+ elif (pedestal_dist <= self._pedestal_zone_distance and
+ not was_in_pedestal_zone):
+ crossings.append(i)
+ was_in_pedestal_zone = True
+ if len(crossings) < 3:
+ raise RuntimeError(
+ 'Failed to segment the given trajectory: '
+ 'walker should cross the pedestal zone\'s boundary >= 3 times '
+ 'but got {}'.format(len(crossings)))
+ elif len(crossings) == 3:
+ crossings.append(len(self._proto.timesteps) - 1)
+ elif len(crossings) > 4:
+ crossings = [crossings[0], crossings[1], crossings[-2], crossings[-1]]
+
+ # Identify the pick up event during the first in-zone interval.
+ start_position = np.array(self._proto.timesteps[0].props[0].position)
+ end_position = np.array(self._proto.timesteps[-1].props[0].position)
+ pick_up_step = crossings[1] - 1
+ while pick_up_step > crossings[0]:
+ prev_position = self._proto.timesteps[pick_up_step - 1].props[0].position
+ if np.linalg.norm(start_position[2] - prev_position[2]) < 0.001:
+ break
+ pick_up_step -= 1
+
+ # Identify the put down event during the second in-zone interval.
+ put_down_step = crossings[2]
+ while put_down_step <= crossings[3]:
+ next_position = self._proto.timesteps[put_down_step + 1].props[0].position
+ if np.linalg.norm(end_position[2] - next_position[2]) < 0.001:
+ break
+ put_down_step += 1
+
+ carry_halfway_step = int((crossings[1] + crossings[2]) / 2)
+
+ self._segment_intervals = {
+ ClipSegment.APPROACH: (0, crossings[0]),
+ ClipSegment.PICKUP: (crossings[0], pick_up_step),
+ ClipSegment.CARRY1: (pick_up_step, carry_halfway_step),
+ ClipSegment.CARRY2: (carry_halfway_step, crossings[2]),
+ ClipSegment.PUTDOWN: (crossings[2], put_down_step),
+ ClipSegment.BACKOFF: (put_down_step, len(self._proto.timesteps))
+ }
+
+ def segment_interval(self, segment):
+ start_step, end_step = self._segment_intervals[segment]
+ return (start_step * self._proto.dt, (end_step - 1) * self._proto.dt)
+
+  def get_random_timestep_in_segment(self, segment, random_state):
+    return self._proto.timesteps[
+        random_state.randint(*self._segment_intervals[segment])]
+
diff --git a/catch_carry/warehouse.py b/catch_carry/warehouse.py
new file mode 100644
index 0000000..7f83f19
--- /dev/null
+++ b/catch_carry/warehouse.py
@@ -0,0 +1,681 @@
+# Copyright 2020 Deepmind Technologies Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A prop-carry task that transition between multiple phases."""
+
+import collections
+import colorsys
+import enum
+
+from absl import logging
+from dm_control import composer
+from dm_control import mjcf
+from dm_control.composer.observation import observable
+from dm_control.locomotion.arenas import floors
+from dm_control.locomotion.mocap import loader as mocap_loader
+from dm_control.mujoco.wrapper import mjbindings
+import numpy as np
+
+from catch_carry import arm_opener
+from catch_carry import mocap_data
+from catch_carry import props
+from catch_carry import trajectories
+
+_PHYSICS_TIMESTEP = 0.005
+
+# Maximum number of physics steps to run when settling props onto pedestals
+# during episode initialization.
+_MAX_SETTLE_STEPS = 1000
+
+# Maximum velocity for prop to be considered settled.
+# Used during episode initialization only.
+_SETTLE_QVEL_TOL = 1e-5
+
+# Magnitude of the sparse reward.
+_SPARSE_REWARD = 1.0
+
+# Maximum distance for walkers to be considered to be "near" a pedestal/target.
+_TARGET_TOL = 0.65
+
+# Defines how pedestals are placed around the arena.
+# Pedestals are placed at constant angle intervals around the arena's center.
+_BASE_PEDESTAL_DIST = 3 # Base distance from center.
+_PEDESTAL_DIST_DELTA = 0.5 # Maximum variation on the base distance.
+
+# Base hue-luminosity-saturation of the pedestal colors.
+# We rotate through the hue for each pedestal created in the environment.
+_BASE_PEDESTAL_H = 0.1
+_BASE_PEDESTAL_L = 0.3
+_BASE_PEDESTAL_S = 0.7
+
+# Pedestal luminosity when active.
+_ACTIVATED_PEDESTAL_L = 0.8
+
+_PEDESTAL_SIZE = (0.2, 0.2, 0.02)
+
+_SINGLE_PEDESTAL_COLOR = colorsys.hls_to_rgb(.3, .15, .35) + (1.0,)
+
+WALKER_PEDESTAL = 'walker_pedestal'
+WALKER_PROP = 'walker_prop'
+PROP_PEDESTAL = 'prop_pedestal'
+TARGET_STATE = 'target_state/'
+CURRENT_STATE = 'meta/current_state/'
+
+
+def _is_same_state(state_1, state_2):
+ if state_1.keys() != state_2.keys():
+ return False
+ for k in state_1:
+ if not np.all(state_1[k] == state_2[k]):
+ return False
+ return True
+
+
+def _singleton_or_none(iterable):
+ iterator = iter(iterable)
+ try:
+ return next(iterator)
+ except StopIteration:
+ return None
+
+
+def _generate_pedestal_colors(num_pedestals):
+ """Function to get colors for pedestals."""
+ colors = []
+ for i in range(num_pedestals):
+ h = _BASE_PEDESTAL_H + i / num_pedestals
+ while h > 1:
+ h -= 1
+ colors.append(
+ colorsys.hls_to_rgb(h, _BASE_PEDESTAL_L, _BASE_PEDESTAL_S) + (1.0,))
+ return colors
+
+
+InitializationParameters = collections.namedtuple(
+ 'InitializationParameters', ('clip_segment', 'prop_id', 'pedestal_id'))
+
+
+def _rotate_vector_by_quaternion(vec, quat):
+ result = np.empty(3)
+ mjbindings.mjlib.mju_rotVecQuat(result, np.asarray(vec), np.asarray(quat))
+ return result
+
+
+@enum.unique
+class WarehousePhase(enum.Enum):
+ TERMINATED = 0
+ GOTOTARGET = 1
+ PICKUP = 2
+ CARRYTOTARGET = 3
+ PUTDOWN = 4
+
+
+def _find_random_free_pedestal_id(target_state, random_state):
+ free_pedestals = (
+ np.where(np.logical_not(np.any(target_state, axis=0)))[0])
+ return random_state.choice(free_pedestals)
+
+
+def _find_random_occupied_pedestal_id(target_state, random_state):
+ occupied_pedestals = (
+ np.where(np.any(target_state, axis=0))[0])
+ return random_state.choice(occupied_pedestals)
+
+
+def one_hot(values, num_unique):
+ return np.squeeze(np.eye(num_unique)[np.array(values).reshape(-1)])
+
+
+class SinglePropFourPhases(object):
+ """A phase manager that transitions between four phases for a single prop."""
+
+ def __init__(self, fixed_initialization_phase=None):
+ self._phase = WarehousePhase.TERMINATED
+ self._fixed_initialization_phase = fixed_initialization_phase
+
+ def initialize_episode(self, target_state, random_state):
+ """Randomly initializes an episode into one of the four phases."""
+
+ if self._fixed_initialization_phase is None:
+ self._phase = random_state.choice([
+ WarehousePhase.GOTOTARGET, WarehousePhase.PICKUP,
+ WarehousePhase.CARRYTOTARGET, WarehousePhase.PUTDOWN
+ ])
+ else:
+ self._phase = self._fixed_initialization_phase
+ self._prop_id = random_state.randint(len(target_state[PROP_PEDESTAL]))
+ self._pedestal_id = np.nonzero(
+ target_state[PROP_PEDESTAL][self._prop_id])[0][0]
+ pedestal_id_for_initialization = self._pedestal_id
+
+ if self._phase == WarehousePhase.GOTOTARGET:
+ clip_segment = trajectories.ClipSegment.APPROACH
+ target_state[WALKER_PROP][:] = 0
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+ elif self._phase == WarehousePhase.PICKUP:
+ clip_segment = trajectories.ClipSegment.PICKUP
+ target_state[WALKER_PROP][self._prop_id] = 1
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+ # Set self._pedestal_id to the next pedestal after pickup is successful.
+ self._pedestal_id = _find_random_free_pedestal_id(
+ target_state[PROP_PEDESTAL], random_state)
+ target_state[PROP_PEDESTAL][self._prop_id, :] = 0
+ elif self._phase == WarehousePhase.CARRYTOTARGET:
+ clip_segment = random_state.choice([
+ trajectories.ClipSegment.CARRY1, trajectories.ClipSegment.CARRY2])
+ self._pedestal_id = _find_random_free_pedestal_id(
+ target_state[PROP_PEDESTAL], random_state)
+ if clip_segment == trajectories.ClipSegment.CARRY2:
+ pedestal_id_for_initialization = self._pedestal_id
+ target_state[WALKER_PROP][self._prop_id] = 1
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+ target_state[PROP_PEDESTAL][self._prop_id, :] = 0
+ elif self._phase == WarehousePhase.PUTDOWN:
+ clip_segment = trajectories.ClipSegment.PUTDOWN
+ target_state[WALKER_PROP][:] = 0
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+
+ return InitializationParameters(
+ clip_segment, self._prop_id, pedestal_id_for_initialization)
+
+ def on_success(self, target_state, random_state):
+ """Transitions into the next phase upon success of current phase."""
+ if self._phase == WarehousePhase.GOTOTARGET:
+ if self._prop_id is not None:
+ self._phase = WarehousePhase.PICKUP
+ # Set self._pedestal_id to the next pedestal after pickup is successful.
+ self._pedestal_id = (
+ _find_random_free_pedestal_id(
+ target_state[PROP_PEDESTAL], random_state))
+ target_state[WALKER_PROP][self._prop_id] = 1
+ target_state[PROP_PEDESTAL][self._prop_id, :] = 0
+ else:
+ # If you go to an empty pedestal, go to pedestal with a prop.
+ self._pedestal_id = (
+ _find_random_occupied_pedestal_id(
+ target_state[PROP_PEDESTAL], random_state))
+ target_state[WALKER_PEDESTAL][:] = 0
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+ self._prop_id = np.argwhere(
+ target_state[PROP_PEDESTAL][:, self._pedestal_id])[0, 0]
+ elif self._phase == WarehousePhase.PICKUP:
+ self._phase = WarehousePhase.CARRYTOTARGET
+ target_state[WALKER_PEDESTAL][:] = 0
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+ elif self._phase == WarehousePhase.CARRYTOTARGET:
+ self._phase = WarehousePhase.PUTDOWN
+ target_state[WALKER_PROP][:] = 0
+ target_state[PROP_PEDESTAL][self._prop_id, self._pedestal_id] = 1
+ elif self._phase == WarehousePhase.PUTDOWN:
+ self._phase = WarehousePhase.GOTOTARGET
+ # Set self._pedestal_id to the next pedestal after putdown is successful.
+ self._pedestal_id = (
+ _find_random_free_pedestal_id(
+ target_state[PROP_PEDESTAL], random_state))
+ self._prop_id = None
+ target_state[WALKER_PEDESTAL][:] = 0
+ target_state[WALKER_PEDESTAL][self._pedestal_id] = 1
+ return self._phase
+
+ @property
+ def phase(self):
+ return self._phase
+
+ @property
+ def prop_id(self):
+ return self._prop_id
+
+ @property
+ def pedestal_id(self):
+ return self._pedestal_id
+
+
+class PhasedBoxCarry(composer.Task):
+ """A prop-carry task that transitions between multiple phases."""
+
+ def __init__(
+ self,
+ walker,
+ num_props,
+ num_pedestals,
+ proto_modifier=None,
+ transition_class=SinglePropFourPhases,
+ min_prop_gap=0.05,
+ pedestal_height_range=(0.45, 0.75),
+ log_transitions=False,
+ negative_reward_on_failure_termination=True,
+ use_single_pedestal_color=True,
+ priority_friction=False,
+ fixed_initialization_phase=None):
+ """Initialize phased/instructed box-carrying ("warehouse") task.
+
+ Args:
+ walker: the walker to be used in this task.
+ num_props: the number of props in the task scene.
+ num_pedestals: the number of floating shelves (pedestals) in the task
+ scene.
+ proto_modifier: function to modify trajectory proto.
+ transition_class: the object that handles the transition logic.
+      min_prop_gap: minimum gap to leave around the prop; the arms are
+        automatically opened to this gap to avoid problematic collisions upon
+        initialization.
+ pedestal_height_range: range of heights for the pedestal.
+ log_transitions: logging/printing of transitions.
+ negative_reward_on_failure_termination: boolean for whether to provide
+ negative sparse rewards on failure termination.
+ use_single_pedestal_color: boolean option for pedestals being the same
+ color or different colors.
+      priority_friction: sets friction priority, giving the prop objects higher
+        friction.
+ fixed_initialization_phase: an instance of the `WarehousePhase` enum that
+ specifies the phase in which to always initialize the task, or `None` if
+ the initial task phase should be chosen randomly for each episode.
+ """
+ self._num_props = num_props
+ self._num_pedestals = num_pedestals
+ self._proto_modifier = proto_modifier
+ self._transition_manager = transition_class(
+ fixed_initialization_phase=fixed_initialization_phase)
+ self._min_prop_gap = min_prop_gap
+ self._pedestal_height_range = pedestal_height_range
+ self._log_transitions = log_transitions
+ self._target_state = collections.OrderedDict([
+ (WALKER_PEDESTAL, np.zeros(num_pedestals)),
+ (WALKER_PROP, np.zeros(num_props)),
+ (PROP_PEDESTAL, np.zeros([num_props, num_pedestals]))
+ ])
+ self._current_state = collections.OrderedDict([
+ (WALKER_PEDESTAL, np.zeros(num_pedestals)),
+ (WALKER_PROP, np.zeros(num_props)),
+ (PROP_PEDESTAL, np.zeros([num_props, num_pedestals]))
+ ])
+ self._negative_reward_on_failure_termination = (
+ negative_reward_on_failure_termination)
+ self._priority_friction = priority_friction
+
+ clips = sorted(
+ set(mocap_data.medium_pedestal())
+ & (set(mocap_data.small_box()) | set(mocap_data.large_box())))
+ loader = mocap_loader.HDF5TrajectoryLoader(
+ mocap_data.H5_PATH, trajectories.SinglePropCarrySegmentedTrajectory)
+ self._trajectories = [
+ loader.get_trajectory(clip.clip_identifier) for clip in clips]
+
+ self._arena = floors.Floor()
+
+ self._walker = walker
+ self._feet_geoms = (
+ walker.mjcf_model.find('body', 'lfoot').find_all('geom') +
+ walker.mjcf_model.find('body', 'rfoot').find_all('geom'))
+ self._lhand_geoms = (
+ walker.mjcf_model.find('body', 'lhand').find_all('geom'))
+ self._rhand_geoms = (
+ walker.mjcf_model.find('body', 'rhand').find_all('geom'))
+ self._trajectories[0].configure_walkers([self._walker])
+ walker.create_root_joints(self._arena.attach(walker))
+
+ control_timestep = self._trajectories[0].dt
+ for i, trajectory in enumerate(self._trajectories):
+ if trajectory.dt != control_timestep:
+ raise ValueError(
+ 'Inconsistent control timestep: '
+ 'trajectories[{}].dt == {} but trajectories[0].dt == {}'
+ .format(i, trajectory.dt, control_timestep))
+ self.set_timesteps(control_timestep, _PHYSICS_TIMESTEP)
+
+ if use_single_pedestal_color:
+ self._pedestal_colors = [_SINGLE_PEDESTAL_COLOR] * num_pedestals
+ else:
+ self._pedestal_colors = _generate_pedestal_colors(num_pedestals)
+ self._pedestals = [props.Pedestal(_PEDESTAL_SIZE, rgba)
+ for rgba in self._pedestal_colors]
+ for pedestal in self._pedestals:
+ self._arena.attach(pedestal)
+
+ self._props = [
+ self._trajectories[0].create_props(
+ priority_friction=self._priority_friction)[0]
+ for _ in range(num_props)
+ ]
+ for prop in self._props:
+ self._arena.add_free_entity(prop)
+
+ self._task_observables = collections.OrderedDict()
+
+ self._task_observables['target_phase'] = observable.Generic(
+ lambda _: one_hot(self._transition_manager.phase.value, num_unique=5))
+
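+    # The remaining task observables describe the current "focal" prop and
+    # pedestal: egocentric positions relative to the walker (plus the prop's
+    # world-frame z-axis), with zeros returned whenever there is no focal
+    # object in the current phase.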
+ def ego_prop_xpos(physics):
+ prop_id = self._focal_prop_id
+ if prop_id is None:
+ return np.zeros((3,))
+ prop = self._props[prop_id]
+ prop_xpos, _ = prop.get_pose(physics)
+ walker_xpos = physics.bind(self._walker.root_body).xpos
+ return self._walker.transform_vec_to_egocentric_frame(
+ physics, prop_xpos - walker_xpos)
+ self._task_observables['target_prop/xpos'] = (
+ observable.Generic(ego_prop_xpos))
+
+ def prop_zaxis(physics):
+ prop_id = self._focal_prop_id
+ if prop_id is None:
+ return np.zeros((3,))
+ prop = self._props[prop_id]
+ prop_xmat = physics.bind(
+ mjcf.get_attachment_frame(prop.mjcf_model)).xmat
+ return prop_xmat[[2, 5, 8]]
+ self._task_observables['target_prop/zaxis'] = (
+ observable.Generic(prop_zaxis))
+
+ def ego_pedestal_xpos(physics):
+ pedestal_id = self._focal_pedestal_id
+ if pedestal_id is None:
+ return np.zeros((3,))
+ pedestal = self._pedestals[pedestal_id]
+ pedestal_xpos, _ = pedestal.get_pose(physics)
+ walker_xpos = physics.bind(self._walker.root_body).xpos
+ return self._walker.transform_vec_to_egocentric_frame(
+ physics, pedestal_xpos - walker_xpos)
+ self._task_observables['target_pedestal/xpos'] = (
+ observable.Generic(ego_pedestal_xpos))
+
+ for obs in (self._walker.observables.proprioception +
+ self._walker.observables.kinematic_sensors +
+ self._walker.observables.dynamic_sensors +
+ list(self._task_observables.values())):
+ obs.enabled = True
+
+ self._focal_prop_id = None
+ self._focal_pedestal_id = None
+
+ @property
+ def root_entity(self):
+ return self._arena
+
+ @property
+ def task_observables(self):
+ return self._task_observables
+
+ @property
+ def name(self):
+ return 'warehouse'
+
+ def initialize_episode_mjcf(self, random_state):
+ self._reward = 0.0
+ self._discount = 1.0
+ self._should_terminate = False
+ self._before_step_success = False
+ for target_value in self._target_state.values():
+ target_value[:] = 0
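+    # Arrange the pedestals evenly around a circle, with randomized distance
+    # from the origin and randomized height.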
+ for pedestal_id, pedestal in enumerate(self._pedestals):
+ angle = 2 * np.pi * pedestal_id / len(self._pedestals)
+ dist = (_BASE_PEDESTAL_DIST +
+ _PEDESTAL_DIST_DELTA * random_state.uniform(-1, 1))
+
+ height = random_state.uniform(*self._pedestal_height_range)
+ pedestal_pos = [dist * np.cos(angle), dist * np.sin(angle),
+ height - pedestal.geom.size[2]]
+ mjcf.get_attachment_frame(pedestal.mjcf_model).pos = pedestal_pos
+
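+    # Props are re-created every episode from a randomly chosen clip
+    # (optionally passed through the proto modifier), so prop properties can
+    # differ across episodes.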
+ for prop in self._props:
+ prop.detach()
+ self._props = []
+ self._trajectory_for_prop = []
+
+ for prop_id in range(self._num_props):
+ trajectory = random_state.choice(self._trajectories)
+ if self._proto_modifier:
+ trajectory = trajectory.get_modified_trajectory(
+ self._proto_modifier, random_state=random_state)
+ prop = trajectory.create_props(
+ priority_friction=self._priority_friction)[0]
+ prop.mjcf_model.model = 'prop_{}'.format(prop_id)
+ self._arena.add_free_entity(prop)
+ self._props.append(prop)
+ self._trajectory_for_prop.append(trajectory)
+
+ def _settle_props(self, physics):
+ prop_freejoints = [mjcf.get_attachment_frame(prop.mjcf_model).freejoint
+ for prop in self._props]
+ physics.bind(prop_freejoints).qvel = 0
+ physics.forward()
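+    # Step the physics until the props have stopped moving and the target
+    # state is satisfied (or the step budget runs out), then reset the clock.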
+ for _ in range(_MAX_SETTLE_STEPS):
+ self._update_current_state(physics)
+ success = self._evaluate_target_state()
+      stopped = np.max(np.abs(physics.bind(prop_freejoints).qvel)) < _SETTLE_QVEL_TOL
+ if success and stopped:
+ break
+ else:
+ physics.step()
+ physics.data.time = 0
+
+ def initialize_episode(self, physics, random_state):
+ self._ground_geomid = physics.bind(
+ self._arena.mjcf_model.worldbody.geom[0]).element_id
+ self._feet_geomids = set(physics.bind(self._feet_geoms).element_id)
+ self._lhand_geomids = set(physics.bind(self._lhand_geoms).element_id)
+ self._rhand_geomids = set(physics.bind(self._rhand_geoms).element_id)
+
+ for prop_id in range(len(self._props)):
+ pedestal_id = _find_random_free_pedestal_id(
+ self._target_state[PROP_PEDESTAL], random_state)
+ self._target_state[PROP_PEDESTAL][prop_id, pedestal_id] = 1
+
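+    # Rest each prop on its assigned pedestal: take the prop pose from the
+    # first timestep of its clip, shift it onto the scene pedestal, add a
+    # small random horizontal offset, then let the physics settle.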
+ for prop_id, prop in enumerate(self._props):
+ trajectory = self._trajectory_for_prop[prop_id]
+ pedestal_id = np.nonzero(
+ self._target_state[PROP_PEDESTAL][prop_id])[0][0]
+ pedestal = self._pedestals[pedestal_id]
+ pedestal_pos, _ = pedestal.get_pose(physics)
+ pedestal_delta = np.array(
+ pedestal_pos - trajectory.infer_pedestal_positions()[0])
+ pedestal_delta[2] += pedestal.geom.size[2]
+ prop_timestep = trajectory.get_timestep_data(0).props[0]
+ prop_pos = prop_timestep.position + np.array(pedestal_delta)
+ prop_quat = prop_timestep.quaternion
+ prop_pos[:2] += random_state.uniform(
+ -pedestal.geom.size[:2] / 2, pedestal.geom.size[:2] / 2)
+ prop.set_pose(physics, prop_pos, prop_quat)
+ self._settle_props(physics)
+
+ init_params = self._transition_manager.initialize_episode(
+ self._target_state, random_state)
+ if self._log_transitions:
+ logging.info(init_params)
+ self._on_transition(physics)
+
+ init_prop = self._props[init_params.prop_id]
+ init_pedestal = self._pedestals[init_params.pedestal_id]
+ self._init_prop_id = init_params.prop_id
+ self._init_pedestal_id = init_params.pedestal_id
+ init_trajectory = self._trajectory_for_prop[init_params.prop_id]
+ init_timestep = init_trajectory.get_random_timestep_in_segment(
+ init_params.clip_segment, random_state)
+
+ trajectory_pedestal_pos = init_trajectory.infer_pedestal_positions()[0]
+ init_pedestal_pos = np.array(init_pedestal.get_pose(physics)[0])
+ delta_pos = init_pedestal_pos - trajectory_pedestal_pos
+ delta_pos[2] = 0
+ delta_angle = np.pi + np.arctan2(init_pedestal_pos[1], init_pedestal_pos[0])
+ delta_quat = (np.cos(delta_angle / 2), 0, 0, np.sin(delta_angle / 2))
+
+ trajectory_pedestal_to_walker = (
+ init_timestep.walkers[0].position - trajectory_pedestal_pos)
+ rotated_pedestal_to_walker = _rotate_vector_by_quaternion(
+ trajectory_pedestal_to_walker, delta_quat)
+
+ self._walker.set_pose(
+ physics,
+ position=trajectory_pedestal_pos + rotated_pedestal_to_walker,
+ quaternion=init_timestep.walkers[0].quaternion)
+ self._walker.set_velocity(
+ physics, velocity=init_timestep.walkers[0].velocity,
+ angular_velocity=init_timestep.walkers[0].angular_velocity)
+ self._walker.shift_pose(
+ physics, position=delta_pos, quaternion=delta_quat,
+ rotate_velocity=True)
+ physics.bind(self._walker.mocap_joints).qpos = (
+ init_timestep.walkers[0].joints)
+ physics.bind(self._walker.mocap_joints).qvel = (
+ init_timestep.walkers[0].joints_velocity)
+
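+    # For carry and putdown segments the prop starts in the walker's hands,
+    # so it too is initialized from the clip and shifted by the same
+    # transform rather than being left on a pedestal.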
+ if init_params.clip_segment in (trajectories.ClipSegment.CARRY1,
+ trajectories.ClipSegment.CARRY2,
+ trajectories.ClipSegment.PUTDOWN):
+ trajectory_pedestal_to_prop = (
+ init_timestep.props[0].position - trajectory_pedestal_pos)
+ rotated_pedestal_to_prop = _rotate_vector_by_quaternion(
+ trajectory_pedestal_to_prop, delta_quat)
+ init_prop.set_pose(
+ physics,
+ position=trajectory_pedestal_pos + rotated_pedestal_to_prop,
+ quaternion=init_timestep.props[0].quaternion)
+ init_prop.set_velocity(
+ physics, velocity=init_timestep.props[0].velocity,
+ angular_velocity=init_timestep.props[0].angular_velocity)
+ init_prop.shift_pose(
+ physics, position=delta_pos,
+ quaternion=delta_quat, rotate_velocity=True)
+
+ # If we have moved the pedestal upwards during height initialization,
+ # the prop may now be lodged inside it. We fix that here.
+ if init_pedestal_pos[2] > trajectory_pedestal_pos[2]:
+ init_prop_geomid = physics.bind(init_prop.geom).element_id
+ init_pedestal_geomid = physics.bind(init_pedestal.geom).element_id
+ disallowed_contact = sorted((init_prop_geomid, init_pedestal_geomid))
+ def has_disallowed_contact():
+ physics.forward()
+ for contact in physics.data.contact:
+ if sorted((contact.geom1, contact.geom2)) == disallowed_contact:
+ return True
+ return False
+ while has_disallowed_contact():
+ init_prop.shift_pose(physics, (0, 0, 0.001))
+
+ self._move_arms_if_necessary(physics)
+ self._update_current_state(physics)
+ self._previous_step_success = self._evaluate_target_state()
+
+ self._focal_prop_id = self._init_prop_id
+ self._focal_pedestal_id = self._init_pedestal_id
+
+ def _move_arms_if_necessary(self, physics):
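+    # Open the walker's arms to leave at least `min_prop_gap` of clearance
+    # around each prop and pedestal, so the episode does not begin with the
+    # hands penetrating an object.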
+ if self._min_prop_gap is not None:
+ for entity in self._props + self._pedestals:
+ try:
+ arm_opener.open_arms_for_prop(
+ physics, self._walker.left_arm_root, self._walker.right_arm_root,
+ entity.mjcf_model, self._min_prop_gap)
+ except RuntimeError as e:
+ raise composer.EpisodeInitializationError(e)
+
+ def after_step(self, physics, random_state):
+    # First check for failure termination: any contact between the ground and
+    # a non-foot geom ends the episode with zero discount.
+ for contact in physics.data.contact:
+ if ((contact.geom1 == self._ground_geomid and
+ contact.geom2 not in self._feet_geomids) or
+ (contact.geom2 == self._ground_geomid and
+ contact.geom1 not in self._feet_geomids)):
+ if self._negative_reward_on_failure_termination:
+ self._reward = -_SPARSE_REWARD
+ else:
+ self._reward = 0.0
+ self._should_terminate = True
+ self._discount = 0.0
+ return
+
+ # Then check for normal reward and state transitions.
+ self._update_current_state(physics)
+ success = self._evaluate_target_state()
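+    # A sparse reward is given only on the step where the target state is
+    # first achieved; the transition manager then issues the next phase.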
+ if success and not self._previous_step_success:
+ self._reward = _SPARSE_REWARD
+ new_phase = (
+ self._transition_manager.on_success(self._target_state, random_state))
+ self._should_terminate = (new_phase == WarehousePhase.TERMINATED)
+ self._on_transition(physics)
+ self._previous_step_success = self._evaluate_target_state()
+ else:
+ self._reward = 0.0
+
+ def _on_transition(self, physics):
+ self._focal_prop_id = self._transition_manager.prop_id
+ self._focal_pedestal_id = self._transition_manager.pedestal_id
+ if self._log_transitions:
+ logging.info('target_state:\n%s', self._target_state)
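+    # Recolor the pedestals: the currently targeted pedestal is highlighted
+    # by raising its lightness in HLS space.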
+ for pedestal_id, pedestal_active in enumerate(
+ self._target_state[WALKER_PEDESTAL]):
+ r, g, b, a = self._pedestal_colors[pedestal_id]
+ if pedestal_active:
+ h, _, s = colorsys.rgb_to_hls(r, g, b)
+ r, g, b = colorsys.hls_to_rgb(h, _ACTIVATED_PEDESTAL_L, s)
+ physics.bind(self._pedestals[pedestal_id].geom).rgba = (r, g, b, a)
+
+ def get_reward(self, physics):
+ return self._reward
+
+ def get_discount(self, physics):
+ return self._discount
+
+ def should_terminate_episode(self, physics):
+ return self._should_terminate
+
+ def _update_current_state(self, physics):
+ for current_state_value in self._current_state.values():
+ current_state_value[:] = 0
+
+ # Check if the walker is near each pedestal.
+ walker_pos, _ = self._walker.get_pose(physics)
+ for pedestal_id, pedestal in enumerate(self._pedestals):
+ target_pos, _ = pedestal.get_pose(physics)
+ walker_to_target_dist = np.linalg.norm(walker_pos[:2] - target_pos[:2])
+ if walker_to_target_dist <= _TARGET_TOL:
+ self._current_state[WALKER_PEDESTAL][pedestal_id] = 1
+
+ prop_geomids = {
+ physics.bind(prop.geom).element_id: prop_id
+ for prop_id, prop in enumerate(self._props)}
+ pedestal_geomids = {
+ physics.bind(pedestal.geom).element_id: pedestal_id
+ for pedestal_id, pedestal in enumerate(self._pedestals)}
+
+ prop_pedestal_contact_counts = np.zeros(
+ [self._num_props, self._num_pedestals])
+ prop_lhand_contact = [False] * self._num_props
+ prop_rhand_contact = [False] * self._num_props
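+    # Scan all contacts to count prop-pedestal contact points and to record
+    # which props are touched by the left and right hands.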
+ for contact in physics.data.contact:
+ prop_id = prop_geomids.get(contact.geom1, prop_geomids.get(contact.geom2))
+ pedestal_id = pedestal_geomids.get(
+ contact.geom1, pedestal_geomids.get(contact.geom2))
+ has_lhand = (contact.geom1 in self._lhand_geomids or
+ contact.geom2 in self._lhand_geomids)
+ has_rhand = (contact.geom1 in self._rhand_geomids or
+ contact.geom2 in self._rhand_geomids)
+ if prop_id is not None and pedestal_id is not None:
+ prop_pedestal_contact_counts[prop_id, pedestal_id] += 1
+ if prop_id is not None and has_lhand:
+ prop_lhand_contact[prop_id] = True
+ if prop_id is not None and has_rhand:
+ prop_rhand_contact[prop_id] = True
+
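+    # A prop counts as held when both hands are touching it, and as resting
+    # on a pedestal when there are at least four contact points between them.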
+ for prop_id in range(self._num_props):
+ if prop_lhand_contact[prop_id] and prop_rhand_contact[prop_id]:
+ self._current_state[WALKER_PROP][prop_id] = 1
+ pedestal_contact_counts = prop_pedestal_contact_counts[prop_id]
+ for pedestal_id in range(self._num_pedestals):
+ if pedestal_contact_counts[pedestal_id] >= 4:
+ self._current_state[PROP_PEDESTAL][prop_id, pedestal_id] = 1
+
+ def _evaluate_target_state(self):
+ return _is_same_state(self._current_state, self._target_state)