diff --git a/README.md b/README.md index add2cd4..ba81b8b 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ https://deepmind.com/research/publications/ ## Projects +* [The Difficulty of Passive Learning in Deep Reinforcement Learning](tandem_dqn), NeurIPS 2021 * [Skilful precipitation nowcasting using deep generative models of radar](nowcasting), Nature 2021 * [Compute-Aided Design as Language](cadl) * [Encoders and ensembles for continual learning](continual_learning) diff --git a/tandem_dqn/README.md b/tandem_dqn/README.md new file mode 100644 index 0000000..50f85b4 --- /dev/null +++ b/tandem_dqn/README.md @@ -0,0 +1,78 @@ +# Tandem DQN + +This repository provides an implementation of the Tandem DQN agent and +main experiments presented in the paper +"The Difficulty of Passive Learning in Deep Reinforcement Learning" +(Georg Ostrovski, Pablo Samuel Castro and Will Dabney, 2021). + +The code is a modified fork of the [DoubleDQN](https://arxiv.org/abs/1509.06461) +agent in the [DQN Zoo](https://github.com/deepmind/dqn_zoo) agent collection. + +## Quick start + +This code can be run with a regular CPU setup, but will only be reasonably fast +(with the given configuration) if run with a GPU accelerator, in which case +you will need an NVIDIA GPU with recent CUDA drivers. +For installation, run + +```shell +git clone git@github.com:deepmind/deepmind-research.git +virtualenv --python=python3.6 "tandem" +source tandem/bin/activate +cd deepmind_research +pip install -r tandem_dqn/requirements.txt +``` + +The code has only been tested with Python 3.6.14 + +## Running the experiments presented in the paper + +To execute the vanilla Tandem DQN experiment on the Pong environment, run: + +```bash +python -m "tandem_dqn.run_tandem" --environment_name=pong --seed=42 +``` + +A number of flags can be specified to customize execution and run various of +the presented experimental variations: + +* `--use_sticky_actions` +(values: `True`, `False`; default: `False`): +Use "sticky actions" variant +([Machado et al 2017](https://arxiv.org/abs/1709.06009)) of the Atari environment. + +* `--network_active`, `--network_passive` +(values: `double_q`, `qr`; default: `double_q`): +Whether to use the DoubleDQN or QR-DQN +([Dabney et al 2017](https://arxiv.org/abs/1710.10044)) network architecture +for active and passive agents (can be set independently). Note this value +needs to be compatible with the chosen loss (see below). + +* `--loss_active`, `--loss_passive` +(values: `double_q`, `double_q_v`, `double_q_p`, `double_q_pv`, `qr`, +`q_regression`; default: `double_q`) +Which loss to use for active and passive agent training. The losses are: + * `double_q`: regular Double-Q-Learning loss. + * `double_q_v`: Double-Q-Learning with bootstrap target _values_ provided + by the respective _other_ agent's target network. + * `double_q_p`: Double-Q-Learning with bootstrap target _policy_ (argmax) + provided by the respective _other_ agent's online network. + * `double_q_pv`: Double-Q-Learning with both boostrap target values and + policy provided by the respective _other_ agent's target and online networks. + * `qr`: Quantile Regression Q-Learning loss + ([Dabney et al 2017](https://arxiv.org/abs/1710.10044)). + * `q_regression`: Supervised regression loss towards the respective other + agent's online network's output values. + +* `--optimizer_active`, `--optimizer_passive` +(values: `adam`, `rmsprop`; default: `rmsprop` for both): +Which optimization algorithm to use for the active and passive network training +(can be set independently). + +* `--exploration_epsilon_end_value` +(values in `[0, 1]`; default: `0.01`): +Value of epsilon parameter in active agent's epsilong-greedy policy after an +initial decay phase. + + + diff --git a/tandem_dqn/agent.py b/tandem_dqn/agent.py new file mode 100644 index 0000000..3018c33 --- /dev/null +++ b/tandem_dqn/agent.py @@ -0,0 +1,378 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tandem DQN agent class.""" + +import typing +from typing import Any, Callable, Mapping, Set, Text + +from absl import logging +import dm_env +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np +import optax +import rlax + +from tandem_dqn import losses +from tandem_dqn import parts +from tandem_dqn import processors +from tandem_dqn import replay as replay_lib + + +class TandemTuple(typing.NamedTuple): + active: Any + passive: Any + + +def tandem_map(fn: Callable[..., Any], *args): + return TandemTuple( + active=fn(*[a.active for a in args]), + passive=fn(*[a.passive for a in args])) + + +def replace_module_params(source, target, modules): + """Replace selected module params in target by corresponding source values.""" + source, _ = hk.data_structures.partition( + lambda module, name, value: module in modules, + source) + return hk.data_structures.merge(target, source) + + +class TandemDqn(parts.Agent): + """Tandem DQN agent.""" + + def __init__( + self, + preprocessor: processors.Processor, + sample_network_input: jnp.ndarray, + network: TandemTuple, + optimizer: TandemTuple, + loss: TandemTuple, + transition_accumulator: Any, + replay: replay_lib.TransitionReplay, + batch_size: int, + exploration_epsilon: Callable[[int], float], + min_replay_capacity_fraction: float, + learn_period: int, + target_network_update_period: int, + tied_layers: Set[str], + rng_key: parts.PRNGKey, + ): + self._preprocessor = preprocessor + self._replay = replay + self._transition_accumulator = transition_accumulator + self._batch_size = batch_size + self._exploration_epsilon = exploration_epsilon + self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity + self._learn_period = learn_period + self._target_network_update_period = target_network_update_period + + # Initialize network parameters and optimizer. + self._rng_key, network_rng_key_active, network_rng_key_passive = ( + jax.random.split(rng_key, 3)) + active_params = network.active.init( + network_rng_key_active, sample_network_input[None, ...]) + passive_params = network.passive.init( + network_rng_key_passive, sample_network_input[None, ...]) + self._online_params = TandemTuple( + active=active_params, passive=passive_params) + self._target_params = self._online_params + self._opt_state = tandem_map( + lambda optim, params: optim.init(params), + optimizer, self._online_params) + + # Other agent state: last action, frame count, etc. + self._action = None + self._frame_t = -1 # Current frame index. + + # Stats. + stats = [ + 'loss_active', + 'loss_passive', + 'frac_diff_argmax', + 'mc_error_active', + 'mc_error_passive', + 'mc_error_abs_active', + 'mc_error_abs_passive', + ] + self._statistics = {k: np.nan for k in stats} + + # Define jitted loss, update, and policy functions here instead of as + # class methods, to emphasize that these are meant to be pure functions + # and should not access the agent object's state via `self`. + + def network_outputs(rng_key, online_params, target_params, transitions): + """Compute all potentially needed outputs of active and passive net.""" + _, *apply_keys = jax.random.split(rng_key, 4) + outputs_tm1 = tandem_map( + lambda net, param: net.apply(param, apply_keys[0], transitions.s_tm1), + network, online_params) + outputs_t = tandem_map( + lambda net, param: net.apply(param, apply_keys[1], transitions.s_t), + network, online_params) + outputs_target_t = tandem_map( + lambda net, param: net.apply(param, apply_keys[2], transitions.s_t), + network, target_params) + return outputs_tm1, outputs_t, outputs_target_t + + # Helper functions to define active and passive losses. + # Active and passive losses are allowed to depend on all active and passive + # outputs, but stop-gradient is used to prevent gradients from flowing + # from active loss to passive network params and vice versa. + def sg_active(x): + return TandemTuple( + active=jax.lax.stop_gradient(x.active), passive=x.passive) + + def sg_passive(x): + return TandemTuple( + active=x.active, passive=jax.lax.stop_gradient(x.passive)) + + def compute_loss(online_params, target_params, transitions, rng_key): + rng_key, apply_key = jax.random.split(rng_key) + outputs_tm1, outputs_t, outputs_target_t = network_outputs( + apply_key, online_params, target_params, transitions) + + _, loss_key_active, loss_key_passive = jax.random.split(rng_key, 3) + loss_active = loss.active( + sg_passive(outputs_tm1), sg_passive(outputs_t), outputs_target_t, + transitions, loss_key_active) + loss_passive = loss.passive( + sg_active(outputs_tm1), sg_active(outputs_t), outputs_target_t, + transitions, loss_key_passive) + + # Logging stuff. + a_tm1 = transitions.a_tm1 + mc_return_tm1 = transitions.mc_return_tm1 + q_values = TandemTuple( + active=outputs_tm1.active.q_values, + passive=outputs_tm1.passive.q_values) + mc_error = jax.tree_map( + lambda q: losses.batch_mc_learning(q, a_tm1, mc_return_tm1), + q_values) + mc_error_abs = jax.tree_map(jnp.abs, mc_error) + q_argmax = jax.tree_map(lambda q: jnp.argmax(q, axis=-1), q_values) + argmax_diff = jnp.not_equal(q_argmax.active, q_argmax.passive) + + batch_mean = lambda x: jnp.mean(x, axis=0) + logs = { + 'loss_active': loss_active, + 'loss_passive': loss_passive + } + logs.update(jax.tree_map(batch_mean, { + 'frac_diff_argmax': argmax_diff, + 'mc_error_active': mc_error.active, + 'mc_error_passive': mc_error.passive, + 'mc_error_abs_active': mc_error_abs.active, + 'mc_error_abs_passive': mc_error_abs.passive, + })) + return loss_active + loss_passive, logs + + def optim_update(optim, online_params, d_loss_d_params, opt_state): + updates, new_opt_state = optim.update(d_loss_d_params, opt_state) + new_online_params = optax.apply_updates(online_params, updates) + return new_opt_state, new_online_params + + def compute_loss_grad(rng_key, online_params, target_params, transitions): + rng_key, grad_key = jax.random.split(rng_key) + (_, logs), d_loss_d_params = jax.value_and_grad( + compute_loss, has_aux=True)( + online_params, target_params, transitions, grad_key) + return rng_key, logs, d_loss_d_params + + def update_active(rng_key, opt_state, online_params, target_params, + transitions): + """Applies learning update for active network only.""" + rng_key, logs, d_loss_d_params = compute_loss_grad( + rng_key, online_params, target_params, transitions) + new_opt_state_active, new_online_params_active = optim_update( + optimizer.active, online_params.active, d_loss_d_params.active, + opt_state.active) + new_opt_state = opt_state._replace( + active=new_opt_state_active) + new_online_params = online_params._replace( + active=new_online_params_active) + return rng_key, new_opt_state, new_online_params, logs + + self._update_active = jax.jit(update_active) + + def update_passive(rng_key, opt_state, online_params, target_params, + transitions): + """Applies learning update for passive network only.""" + rng_key, logs, d_loss_d_params = compute_loss_grad( + rng_key, online_params, target_params, transitions) + new_opt_state_passive, new_online_params_passive = optim_update( + optimizer.passive, online_params.passive, d_loss_d_params.passive, + opt_state.passive) + new_opt_state = opt_state._replace( + passive=new_opt_state_passive) + new_online_params = online_params._replace( + passive=new_online_params_passive) + return rng_key, new_opt_state, new_online_params, logs + + self._update_passive = jax.jit(update_passive) + + def update_active_passive(rng_key, opt_state, online_params, + target_params, transitions): + """Applies learning update for both active & passive networks.""" + rng_key, logs, d_loss_d_params = compute_loss_grad( + rng_key, online_params, target_params, transitions) + + new_opt_state_active, new_online_params_active = optim_update( + optimizer.active, online_params.active, d_loss_d_params.active, + opt_state.active) + new_opt_state_passive, new_online_params_passive = optim_update( + optimizer.passive, online_params.passive, d_loss_d_params.passive, + opt_state.passive) + new_opt_state = TandemTuple(active=new_opt_state_active, + passive=new_opt_state_passive) + new_online_params = TandemTuple(active=new_online_params_active, + passive=new_online_params_passive) + return rng_key, new_opt_state, new_online_params, logs + + self._update_active_passive = jax.jit(update_active_passive) + + self._update = None # set_training_mode needs to be called to set this. + + def sync_tied_layers(online_params): + """Set tied layer params of passive to respective values of active.""" + new_online_params_passive = replace_module_params( + source=online_params.active, target=online_params.passive, + modules=tied_layers) + return online_params._replace(passive=new_online_params_passive) + + self._sync_tied_layers = jax.jit(sync_tied_layers) + + def select_action(rng_key, network_params, s_t, exploration_epsilon): + """Samples action from eps-greedy policy wrt Q-values at given state.""" + rng_key, apply_key, policy_key = jax.random.split(rng_key, 3) + q_t = network.active.apply(network_params, apply_key, + s_t[None, ...]).q_values[0] + a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon) + return rng_key, a_t + + self._select_action = jax.jit(select_action) + + def step(self, timestep: dm_env.TimeStep) -> parts.Action: + """Selects action given timestep and potentially learns.""" + self._frame_t += 1 + + timestep = self._preprocessor(timestep) + + if timestep is None: # Repeat action. + action = self._action + else: + action = self._action = self._act(timestep) + + for transition in self._transition_accumulator.step(timestep, action): + self._replay.add(transition) + + if self._replay.size < self._min_replay_capacity: + return action + + if self._frame_t % self._learn_period == 0: + self._learn() + + if self._frame_t % self._target_network_update_period == 0: + self._target_params = self._online_params + + return action + + def reset(self) -> None: + """Resets the agent's episodic state such as frame stack and action repeat. + + This method should be called at the beginning of every episode. + """ + self._transition_accumulator.reset() + processors.reset(self._preprocessor) + self._action = None + + def _act(self, timestep) -> parts.Action: + """Selects action given timestep, according to epsilon-greedy policy.""" + s_t = timestep.observation + network_params = self._online_params.active + self._rng_key, a_t = self._select_action( + self._rng_key, network_params, s_t, self.exploration_epsilon) + return parts.Action(jax.device_get(a_t)) + + def _learn(self) -> None: + """Samples a batch of transitions from replay and learns from it.""" + logging.log_first_n(logging.INFO, 'Begin learning', 1) + transitions = self._replay.sample(self._batch_size) + self._rng_key, self._opt_state, self._online_params, logs = self._update( + self._rng_key, + self._opt_state, + self._online_params, + self._target_params, + transitions, + ) + self._online_params = self._sync_tied_layers(self._online_params) + self._statistics.update(jax.device_get(logs)) + + def set_training_mode(self, mode: str): + """Sets training mode to one of 'active', 'passive', or 'active_passive'.""" + if mode == 'active': + self._update = self._update_active + elif mode == 'passive': + self._update = self._update_passive + elif mode == 'active_passive': + self._update = self._update_active_passive + + @property + def online_params(self) -> TandemTuple: + """Returns current parameters of Q-network.""" + return self._online_params + + @property + def statistics(self) -> Mapping[Text, float]: + """Returns current agent statistics as a dictionary.""" + # Check for DeviceArrays in values as this can be very slow. + assert all( + not isinstance(x, jnp.DeviceArray) for x in self._statistics.values()) + return self._statistics + + @property + def exploration_epsilon(self) -> float: + """Returns epsilon value currently used by (eps-greedy) behavior policy.""" + return self._exploration_epsilon(self._frame_t) + + def get_state(self) -> Mapping[Text, Any]: + """Retrieves agent state as a dictionary (e.g. for serialization).""" + state = { + 'rng_key': self._rng_key, + 'frame_t': self._frame_t, + 'opt_state_active': self._opt_state.active, + 'online_params_active': self._online_params.active, + 'target_params_active': self._target_params.active, + 'opt_state_passive': self._opt_state.passive, + 'online_params_passive': self._online_params.passive, + 'target_params_passive': self._target_params.passive, + 'replay': self._replay.get_state(), + } + return state + + def set_state(self, state: Mapping[Text, Any]) -> None: + """Sets agent state from a (potentially de-serialized) dictionary.""" + self._rng_key = state['rng_key'] + self._frame_t = state['frame_t'] + self._opt_state = TandemTuple( + active=jax.device_put(state['opt_state_active']), + passive=jax.device_put(state['opt_state_passive'])) + self._online_params = TandemTuple( + active=jax.device_put(state['online_params_active']), + passive=jax.device_put(state['online_params_passive'])) + self._target_params = TandemTuple( + active=jax.device_put(state['target_params_active']), + passive=jax.device_put(state['target_params_passive'])) + self._replay.set_state(state['replay']) diff --git a/tandem_dqn/atari_data.py b/tandem_dqn/atari_data.py new file mode 100644 index 0000000..e42b428 --- /dev/null +++ b/tandem_dqn/atari_data.py @@ -0,0 +1,110 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utilities to compute human-normalized Atari scores. + +The data used in this module is human and random performance data on Atari-57. +It comprises of evaluation scores (undiscounted returns), each averaged +over at least 3 episode runs, on each of the 57 Atari games. Each episode begins +with the environment already stepped with a uniform random number (between 1 and +30 inclusive) of noop actions. + +The two agents are: +* 'random' (agent choosing its actions uniformly randomly on each step) +* 'human' (professional human game tester) + +Scores are obtained by averaging returns over the episodes played by each agent, +with episode length capped to 108,000 frames (i.e. timeout after 30 minutes). + +The term 'human-normalized' here means a linear per-game transformation of +a game score in such a way that 0 corresponds to random performance and 1 +corresponds to human performance. +""" + +import math + +# Game: score-tuple dictionary. Each score tuple contains +# 0: score random (float) and 1: score human (float). +_ATARI_DATA = { + 'alien': (227.8, 7127.7), + 'amidar': (5.8, 1719.5), + 'assault': (222.4, 742.0), + 'asterix': (210.0, 8503.3), + 'asteroids': (719.1, 47388.7), + 'atlantis': (12850.0, 29028.1), + 'bank_heist': (14.2, 753.1), + 'battle_zone': (2360.0, 37187.5), + 'beam_rider': (363.9, 16926.5), + 'berzerk': (123.7, 2630.4), + 'bowling': (23.1, 160.7), + 'boxing': (0.1, 12.1), + 'breakout': (1.7, 30.5), + 'centipede': (2090.9, 12017.0), + 'chopper_command': (811.0, 7387.8), + 'crazy_climber': (10780.5, 35829.4), + 'defender': (2874.5, 18688.9), + 'demon_attack': (152.1, 1971.0), + 'double_dunk': (-18.6, -16.4), + 'enduro': (0.0, 860.5), + 'fishing_derby': (-91.7, -38.7), + 'freeway': (0.0, 29.6), + 'frostbite': (65.2, 4334.7), + 'gopher': (257.6, 2412.5), + 'gravitar': (173.0, 3351.4), + 'hero': (1027.0, 30826.4), + 'ice_hockey': (-11.2, 0.9), + 'jamesbond': (29.0, 302.8), + 'kangaroo': (52.0, 3035.0), + 'krull': (1598.0, 2665.5), + 'kung_fu_master': (258.5, 22736.3), + 'montezuma_revenge': (0.0, 4753.3), + 'ms_pacman': (307.3, 6951.6), + 'name_this_game': (2292.3, 8049.0), + 'phoenix': (761.4, 7242.6), + 'pitfall': (-229.4, 6463.7), + 'pong': (-20.7, 14.6), + 'private_eye': (24.9, 69571.3), + 'qbert': (163.9, 13455.0), + 'riverraid': (1338.5, 17118.0), + 'road_runner': (11.5, 7845.0), + 'robotank': (2.2, 11.9), + 'seaquest': (68.4, 42054.7), + 'skiing': (-17098.1, -4336.9), + 'solaris': (1236.3, 12326.7), + 'space_invaders': (148.0, 1668.7), + 'star_gunner': (664.0, 10250.0), + 'surround': (-10.0, 6.5), + 'tennis': (-23.8, -8.3), + 'time_pilot': (3568.0, 5229.2), + 'tutankham': (11.4, 167.6), + 'up_n_down': (533.4, 11693.2), + 'venture': (0.0, 1187.5), + # Note the random agent score on Video Pinball is sometimes greater than the + # human score under other evaluation methods. + 'video_pinball': (16256.9, 17667.9), + 'wizard_of_wor': (563.5, 4756.5), + 'yars_revenge': (3092.9, 54576.9), + 'zaxxon': (32.5, 9173.3), +} + +_RANDOM_COL = 0 +_HUMAN_COL = 1 + +ATARI_GAMES = tuple(sorted(_ATARI_DATA.keys())) + + +def get_human_normalized_score(game: str, raw_score: float) -> float: + """Converts game score to human-normalized score.""" + game_scores = _ATARI_DATA.get(game, (math.nan, math.nan)) + random, human = game_scores[_RANDOM_COL], game_scores[_HUMAN_COL] + return (raw_score - random) / (human - random) diff --git a/tandem_dqn/gym_atari.py b/tandem_dqn/gym_atari.py new file mode 100644 index 0000000..3171571 --- /dev/null +++ b/tandem_dqn/gym_atari.py @@ -0,0 +1,217 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""dm_env environment wrapper around Gym Atari configured to be like Xitari. + +Gym Atari is built on the Arcade Learning Environment (ALE), whereas Xitari is +an old fork of the ALE. +""" + +# pylint: disable=g-bad-import-order + +from typing import Optional, Tuple + +import atari_py # pylint: disable=unused-import for gym to load Atari games. +import dm_env +from dm_env import specs +import gym +import numpy as np + +from tandem_dqn import atari_data + +_GYM_ID_SUFFIX = '-xitari-v1' +_SA_SUFFIX = '-sa' + + +def _game_id(game, sticky_actions): + return game + (_SA_SUFFIX if sticky_actions else '') + _GYM_ID_SUFFIX + + +def _register_atari_environments(): + """Registers Atari environments in Gym to be as similar to Xitari as possible. + + Main difference from PongNoFrameSkip-v4, etc. is max_episode_steps is unset + and only the usual 57 Atari games are registered. + + Additionally, sticky-actions variants of the environments are registered + with an '-sa' suffix. + """ + for sticky_actions in [False, True]: + for game in atari_data.ATARI_GAMES: + repeat_action_probability = 0.25 if sticky_actions else 0.0 + gym.envs.registration.register( + id=_game_id(game, sticky_actions), + entry_point='gym.envs.atari:AtariEnv', + kwargs={ # Explicitly set all known arguments. + 'game': game, + 'mode': None, # Not necessarily the same as 0. + 'difficulty': None, # Not necessarily the same as 0. + 'obs_type': 'image', + 'frameskip': 1, # Get every frame. + 'repeat_action_probability': repeat_action_probability, + 'full_action_space': False, + }, + max_episode_steps=None, # No time limit, handled in run loop. + nondeterministic=False, # Xitari is deterministic. + ) + + +_register_atari_environments() + + +class GymAtari(dm_env.Environment): + """Gym Atari with a `dm_env.Environment` interface.""" + + def __init__(self, game, sticky_actions, seed): + self._gym_env = gym.make(_game_id(game, sticky_actions)) + self._gym_env.seed(seed) + self._start_of_episode = True + + def reset(self) -> dm_env.TimeStep: + """Resets the environment and starts a new episode.""" + observation = self._gym_env.reset() + lives = np.int32(self._gym_env.ale.lives()) + timestep = dm_env.restart((observation, lives)) + self._start_of_episode = False + return timestep + + def step(self, action: np.int32) -> dm_env.TimeStep: + """Updates the environment given an action and returns a timestep.""" + # If the previous timestep was LAST then we call reset() on the Gym + # environment, otherwise step(). Although Gym environments allow you to step + # through episode boundaries (similar to dm_env) they emit a warning. + if self._start_of_episode: + step_type = dm_env.StepType.FIRST + observation = self._gym_env.reset() + discount = None + reward = None + done = False + else: + observation, reward, done, info = self._gym_env.step(action) + if done: + assert 'TimeLimit.truncated' not in info, 'Should never truncate.' + step_type = dm_env.StepType.LAST + discount = 0. + else: + step_type = dm_env.StepType.MID + discount = 1. + + lives = np.int32(self._gym_env.ale.lives()) + timestep = dm_env.TimeStep( + step_type=step_type, + observation=(observation, lives), + reward=reward, + discount=discount, + ) + self._start_of_episode = done + return timestep + + def observation_spec(self) -> Tuple[specs.Array, specs.Array]: + space = self._gym_env.observation_space + return (specs.Array(shape=space.shape, dtype=space.dtype, name='rgb'), + specs.Array(shape=(), dtype=np.int32, name='lives')) + + def action_spec(self) -> specs.DiscreteArray: + space = self._gym_env.action_space + return specs.DiscreteArray( + num_values=space.n, dtype=np.int32, name='action') + + def close(self): + self._gym_env.close() + + +class RandomNoopsEnvironmentWrapper(dm_env.Environment): + """Adds a random number of noop actions at the beginning of each episode.""" + + def __init__(self, + environment: dm_env.Environment, + max_noop_steps: int, + min_noop_steps: int = 0, + noop_action: int = 0, + seed: Optional[int] = None): + """Initializes the random noops environment wrapper.""" + self._environment = environment + if max_noop_steps < min_noop_steps: + raise ValueError('max_noop_steps must be greater or equal min_noop_steps') + self._min_noop_steps = min_noop_steps + self._max_noop_steps = max_noop_steps + self._noop_action = noop_action + self._rng = np.random.RandomState(seed) + + def reset(self): + """Begins new episode. + + This method resets the wrapped environment and applies a random number + of noop actions before returning the last resulting observation + as the first episode timestep. Intermediate timesteps emitted by the inner + environment (including all rewards and discounts) are discarded. + + Returns: + First episode timestep corresponding to the timestep after a random number + of noop actions are applied to the inner environment. + + Raises: + RuntimeError: if an episode end occurs while the inner environment + is being stepped through with the noop action. + """ + return self._apply_random_noops(initial_timestep=self._environment.reset()) + + def step(self, action): + """Steps environment given action. + + If beginning a new episode then random noops are applied as in `reset()`. + + Args: + action: action to pass to environment conforming to action spec. + + Returns: + `Timestep` from the inner environment unless beginning a new episode, in + which case this is the timestep after a random number of noop actions + are applied to the inner environment. + """ + timestep = self._environment.step(action) + if timestep.first(): + return self._apply_random_noops(initial_timestep=timestep) + else: + return timestep + + def _apply_random_noops(self, initial_timestep): + assert initial_timestep.first() + num_steps = self._rng.randint(self._min_noop_steps, + self._max_noop_steps + 1) + timestep = initial_timestep + for _ in range(num_steps): + timestep = self._environment.step(self._noop_action) + if timestep.last(): + raise RuntimeError('Episode ended while applying %s noop actions.' % + num_steps) + + # We make sure to return a FIRST timestep, i.e. discard rewards & discounts. + return dm_env.restart(timestep.observation) + + ## All methods except for reset and step redirect to the underlying env. + + def observation_spec(self): + return self._environment.observation_spec() + + def action_spec(self): + return self._environment.action_spec() + + def reward_spec(self): + return self._environment.reward_spec() + + def discount_spec(self): + return self._environment.discount_spec() + + def close(self): + return self._environment.close() diff --git a/tandem_dqn/losses.py b/tandem_dqn/losses.py new file mode 100644 index 0000000..1d2ba21 --- /dev/null +++ b/tandem_dqn/losses.py @@ -0,0 +1,190 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Losses for TandemDQN.""" + +from typing import Any, Callable + +import chex +import jax +import jax.numpy as jnp +import rlax + +from tandem_dqn import networks + +# Batch variants of double_q_learning and SARSA. +batch_double_q_learning = jax.vmap(rlax.double_q_learning) +batch_sarsa_learning = jax.vmap(rlax.sarsa) + +# Batch variant of quantile_q_learning with fixed tau input across batch. +batch_quantile_q_learning = jax.vmap( + rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None)) + + +def _mc_learning( + q_tm1: chex.Array, + a_tm1: chex.Numeric, + mc_return_tm1: chex.Array, +) -> chex.Numeric: + """Calculates the MC return error.""" + chex.assert_rank([q_tm1, a_tm1], [1, 0]) + chex.assert_type([q_tm1, a_tm1], [float, int]) + return mc_return_tm1 - q_tm1[a_tm1] + +# Batch variant of MC learning. +batch_mc_learning = jax.vmap(_mc_learning) + + +def _qr_loss(q_tm1, q_t, q_target_t, transitions, rng_key): + """Calculates QR-Learning loss from network outputs and transitions.""" + del q_t, rng_key # Unused. + # Compute Q value distributions. + huber_param = 1. + quantiles = networks.make_quantiles() + losses = batch_quantile_q_learning( + q_tm1.q_dist, + quantiles, + transitions.a_tm1, + transitions.r_t, + transitions.discount_t, + q_target_t.q_dist, # No double Q-learning here. + q_target_t.q_dist, + huber_param, + ) + loss = jnp.mean(losses) + return loss + + +def _sarsa_loss(q_tm1, q_t, transitions, rng_key): + """Calculates SARSA loss from network outputs and transitions.""" + del rng_key # Unused. + grad_error_bound = 1. / 32 + td_errors = batch_sarsa_learning( + q_tm1.q_values, + transitions.a_tm1, + transitions.r_t, + transitions.discount_t, + q_t.q_values, + transitions.a_t + ) + td_errors = rlax.clip_gradient(td_errors, -grad_error_bound, grad_error_bound) + losses = rlax.l2_loss(td_errors) + loss = jnp.mean(losses) + return loss + + +def _mc_loss(q_tm1, transitions, rng_key): + """Calculates Monte-Carlo return loss, i.e. regression towards MC return.""" + del rng_key # Unused. + errors = batch_mc_learning(q_tm1.q_values, transitions.a_tm1, + transitions.mc_return_tm1) + loss = jnp.mean(rlax.l2_loss(errors)) + return loss + + +def _double_q_loss(q_tm1, q_t, q_target_t, transitions, rng_key): + """Calculates Double Q-Learning loss from network outputs and transitions.""" + del rng_key # Unused. + grad_error_bound = 1. / 32 + td_errors = batch_double_q_learning( + q_tm1.q_values, + transitions.a_tm1, + transitions.r_t, + transitions.discount_t, + q_target_t.q_values, + q_t.q_values, + ) + td_errors = rlax.clip_gradient(td_errors, -grad_error_bound, grad_error_bound) + losses = rlax.l2_loss(td_errors) + loss = jnp.mean(losses) + return loss + + +def _q_regression_loss(q_tm1, q_tm1_target): + """Loss for regression of all action values towards targets.""" + errors = q_tm1.q_values - jax.lax.stop_gradient(q_tm1_target.q_values) + loss = jnp.mean(rlax.l2_loss(errors)) + return loss + + +def make_loss_fn(loss_type: str, active: bool) -> Callable[..., Any]: + """Create active or passive loss function of given type.""" + + if active: + primary = lambda x: x.active + secondary = lambda x: x.passive + else: + primary = lambda x: x.passive + secondary = lambda x: x.active + + def sarsa_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """SARSA loss using own networks.""" + del q_t # Unused. + return _sarsa_loss(primary(q_tm1), primary(q_target_t), transitions, + rng_key) + + def mc_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """MonteCarlo loss.""" + del q_t, q_target_t + return _mc_loss(primary(q_tm1), transitions, rng_key) + + def double_q_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """Regular DoubleQ loss using own networks.""" + return _double_q_loss(primary(q_tm1), primary(q_t), primary(q_target_t), + transitions, rng_key) + + def double_q_loss_v_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """DoubleQ loss using other network's (target) value function.""" + return _double_q_loss(primary(q_tm1), primary(q_t), secondary(q_target_t), + transitions, rng_key) + + def double_q_loss_p_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """DoubleQ loss using other network's (online) argmax policy.""" + return _double_q_loss(primary(q_tm1), secondary(q_t), primary(q_target_t), + transitions, rng_key) + + def double_q_loss_pv_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """DoubleQ loss using other network's argmax policy & target value fn.""" + return _double_q_loss(primary(q_tm1), secondary(q_t), secondary(q_target_t), + transitions, rng_key) + + # Pure regression. + def q_regression_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """Pure regression of q_tm1(self) towards q_tm1(other).""" + del q_t, q_target_t, transitions, rng_key # Unused. + return _q_regression_loss(primary(q_tm1), secondary(q_tm1)) + + # QR loss. + def qr_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key): + """QR-Q loss using own networks.""" + return _qr_loss(primary(q_tm1), primary(q_t), primary(q_target_t), + transitions, rng_key) + + if loss_type == 'double_q': + return double_q_loss_fn + elif loss_type == 'sarsa': + return sarsa_loss_fn + elif loss_type == 'mc_return': + return mc_loss_fn + elif loss_type == 'double_q_v': + return double_q_loss_v_fn + elif loss_type == 'double_q_p': + return double_q_loss_p_fn + elif loss_type == 'double_q_pv': + return double_q_loss_pv_fn + elif loss_type == 'q_regression': + return q_regression_loss_fn + elif loss_type == 'qr': + return qr_loss_fn + else: + raise ValueError('Unknown loss "{}"'.format(loss_type)) diff --git a/tandem_dqn/losses_test.py b/tandem_dqn/losses_test.py new file mode 100644 index 0000000..265b979 --- /dev/null +++ b/tandem_dqn/losses_test.py @@ -0,0 +1,95 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for Tandem losses.""" + +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax +from jax.config import config +import numpy as np + +from tandem_dqn import agent +from tandem_dqn import losses +from tandem_dqn import networks +from tandem_dqn import replay + + +def make_tandem_qvals(): + return agent.TandemTuple( + active=networks.QNetworkOutputs(3. * np.ones((3, 5), np.float32)), + passive=networks.QNetworkOutputs(2. * np.ones((3, 5), np.float32)) + ) + + +def make_transition(): + return replay.Transition( + s_tm1=np.zeros(3), a_tm1=np.ones(3, np.int32), r_t=5. * np.ones(3), + discount_t=0.9 * np.ones(3), s_t=np.zeros(3)) + + +class DoubleQLossesTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.qs = make_tandem_qvals() + self.transition = make_transition() + self.rng_key = jax.random.PRNGKey(42) + + @chex.all_variants() + @parameterized.parameters( + ('double_q',), ('double_q_v',), ('double_q_p',), ('double_q_pv',), + ('q_regression',), + ) + def test_active_loss_gradients(self, loss_type): + loss_fn = losses.make_loss_fn(loss_type, active=True) + def fn(q_tm1, q_t, q_t_target, transition, rng_key): + return loss_fn(q_tm1, q_t, q_t_target, transition, rng_key) + grad_fn = self.variant(jax.grad(fn, argnums=(0, 1, 2))) + + dldq_tm1, dldq_t, dldq_t_target = grad_fn( + self.qs, self.qs, self.qs, self.transition, self.rng_key) + # Assert that only active net gets nonzero gradients. + self.assertGreater(np.sum(np.abs(dldq_tm1.active.q_values)), 0.) + self.assertTrue(np.all(dldq_t.active.q_values == 0.)) + self.assertTrue(np.all(dldq_t_target.active.q_values == 0.)) + self.assertTrue(np.all(dldq_t.passive.q_values == 0.)) + self.assertTrue(np.all(dldq_tm1.passive.q_values == 0.)) + self.assertTrue(np.all(dldq_t_target.passive.q_values == 0.)) + + @chex.all_variants() + @parameterized.parameters( + ('double_q',), ('double_q_v',), ('double_q_p',), ('double_q_pv',), + ('q_regression',), + ) + def test_passive_loss_gradients(self, loss_type): + loss_fn = losses.make_loss_fn(loss_type, active=False) + def fn(q_tm1, q_t, q_t_target, transition, rng_key): + return loss_fn(q_tm1, q_t, q_t_target, transition, rng_key) + grad_fn = self.variant(jax.grad(fn, argnums=(0, 1, 2))) + + dldq_tm1, dldq_t, dldq_t_target = grad_fn( + self.qs, self.qs, self.qs, self.transition, self.rng_key) + # Assert that only passive net gets nonzero gradients. + self.assertGreater(np.sum(np.abs(dldq_tm1.passive.q_values)), 0.) + self.assertTrue(np.all(dldq_t.passive.q_values == 0.)) + self.assertTrue(np.all(dldq_t_target.passive.q_values == 0.)) + self.assertTrue(np.all(dldq_t.active.q_values == 0.)) + self.assertTrue(np.all(dldq_tm1.active.q_values == 0.)) + self.assertTrue(np.all(dldq_t_target.active.q_values == 0.)) + + +if __name__ == '__main__': + config.update('jax_numpy_rank_promotion', 'raise') + absltest.main() diff --git a/tandem_dqn/networks.py b/tandem_dqn/networks.py new file mode 100644 index 0000000..1ab29a5 --- /dev/null +++ b/tandem_dqn/networks.py @@ -0,0 +1,216 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DQN agent network components and implementation.""" + +import typing +from typing import Any, Callable, Tuple, Union + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import numpy as np + +Network = hk.Transformed +Params = hk.Params +NetworkFn = Callable[..., Any] + + +class QNetworkOutputs(typing.NamedTuple): + q_values: jnp.ndarray + + +class QRNetworkOutputs(typing.NamedTuple): + q_values: jnp.ndarray + q_dist: jnp.ndarray + + +NUM_QUANTILES = 201 + + +def _dqn_default_initializer( + num_input_units: int) -> hk.initializers.Initializer: + """Default initialization scheme inherited from past implementations of DQN. + + This scheme was historically used to initialize all weights and biases + in convolutional and linear layers of DQN-type agents' networks. + It initializes each weight as an independent uniform sample from [`-c`, `c`], + where `c = 1 / np.sqrt(num_input_units)`, and `num_input_units` is the number + of input units affecting a single output unit in the given layer, i.e. the + total number of inputs in the case of linear (dense) layers, and + `num_input_channels * kernel_width * kernel_height` in the case of + convolutional layers. + + Args: + num_input_units: number of input units to a single output unit of the layer. + + Returns: + Haiku weight initializer. + """ + max_val = np.sqrt(1 / num_input_units) + return hk.initializers.RandomUniform(-max_val, max_val) + + +def make_quantiles(): + """Quantiles for QR-DQN.""" + return (jnp.arange(0, NUM_QUANTILES) + 0.5) / float(NUM_QUANTILES) + + +def conv( + num_features: int, + kernel_shape: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]], + name=None, +) -> NetworkFn: + """Convolutional layer with DQN's legacy weight initialization scheme.""" + + def net_fn(inputs): + """Function representing conv layer with DQN's legacy initialization.""" + num_input_units = inputs.shape[-1] * kernel_shape[0] * kernel_shape[1] + initializer = _dqn_default_initializer(num_input_units) + layer = hk.Conv2D( + num_features, + kernel_shape=kernel_shape, + stride=stride, + w_init=initializer, + b_init=initializer, + padding='VALID', + name=name) + return layer(inputs) + + return net_fn + + +def linear(num_outputs: int, with_bias=True, name=None) -> NetworkFn: + """Linear layer with DQN's legacy weight initialization scheme.""" + + def net_fn(inputs): + """Function representing linear layer with DQN's legacy initialization.""" + initializer = _dqn_default_initializer(inputs.shape[-1]) + layer = hk.Linear( + num_outputs, + with_bias=with_bias, + w_init=initializer, + b_init=initializer, + name=name) + return layer(inputs) + + return net_fn + + +def linear_with_shared_bias(num_outputs: int, name=None) -> NetworkFn: + """Linear layer with single shared bias instead of one bias per output.""" + + def layer_fn(inputs): + """Function representing a linear layer with single shared bias.""" + initializer = _dqn_default_initializer(inputs.shape[-1]) + bias_free_linear = hk.Linear( + num_outputs, with_bias=False, w_init=initializer, name=name) + linear_output = bias_free_linear(inputs) + bias = hk.get_parameter('b', [1], inputs.dtype, init=initializer) + bias = jnp.broadcast_to(bias, linear_output.shape) + return linear_output + bias + + return layer_fn + + +def dqn_torso() -> NetworkFn: + """DQN convolutional torso. + + Includes scaling from [`0`, `255`] (`uint8`) to [`0`, `1`] (`float32`)`. + + Returns: + Network function that `haiku.transform` can be called on. + """ + + def net_fn(inputs): + """Function representing convolutional torso for a DQN Q-network.""" + network = hk.Sequential([ + lambda x: x.astype(jnp.float32) / 255., + conv(32, kernel_shape=(8, 8), stride=(4, 4), name='conv1'), + jax.nn.relu, + conv(64, kernel_shape=(4, 4), stride=(2, 2), name='conv2'), + jax.nn.relu, + conv(64, kernel_shape=(3, 3), stride=(1, 1), name='conv3'), + jax.nn.relu, + hk.Flatten(), + ]) + return network(inputs) + + return net_fn + + +def dqn_value_head(num_actions: int, shared_bias: bool = False) -> NetworkFn: + """Regular DQN Q-value head with single hidden layer.""" + + last_layer = linear_with_shared_bias if shared_bias else linear + + def net_fn(inputs): + """Function representing value head for a DQN Q-network.""" + network = hk.Sequential([ + linear(512, name='linear1'), + jax.nn.relu, + last_layer(num_actions, name='output'), + ]) + return network(inputs) + + return net_fn + + +def qr_atari_network(num_actions: int, quantiles: jnp.ndarray) -> NetworkFn: + """QR-DQN network, expects `uint8` input.""" + + chex.assert_rank(quantiles, 1) + num_quantiles = len(quantiles) + + def net_fn(inputs): + """Function representing QR-DQN Q-network.""" + network = hk.Sequential([ + dqn_torso(), + dqn_value_head(num_quantiles * num_actions), + ]) + network_output = network(inputs) + q_dist = jnp.reshape(network_output, (-1, num_quantiles, num_actions)) + q_values = jnp.mean(q_dist, axis=1) + q_values = jax.lax.stop_gradient(q_values) + return QRNetworkOutputs(q_dist=q_dist, q_values=q_values) + + return net_fn + + +def double_dqn_atari_network(num_actions: int) -> NetworkFn: + """DQN network with shared bias in final layer, expects `uint8` input.""" + + def net_fn(inputs): + """Function representing DQN Q-network with shared bias output layer.""" + network = hk.Sequential([ + dqn_torso(), + dqn_value_head(num_actions, shared_bias=True), + ]) + return QNetworkOutputs(q_values=network(inputs)) + + return net_fn + + +def make_network(network_type: str, num_actions: int) -> Network: + """Constructs network.""" + if network_type == 'double_q': + network_fn = double_dqn_atari_network(num_actions) + elif network_type == 'qr': + quantiles = make_quantiles() + network_fn = qr_atari_network(num_actions, quantiles) + else: + raise ValueError('Unknown network "{}"'.format(network_type)) + + return hk.transform(network_fn) diff --git a/tandem_dqn/parts.py b/tandem_dqn/parts.py new file mode 100644 index 0000000..57d29e1 --- /dev/null +++ b/tandem_dqn/parts.py @@ -0,0 +1,487 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Components for DQN.""" + +import abc +import collections +import csv +import os +import timeit +from typing import Any, Iterable, Mapping, Optional, Text, Tuple, Union + +import dm_env +import jax +import jax.numpy as jnp +import numpy as np +import rlax + +from tandem_dqn import networks +from tandem_dqn import processors + +Action = int +Network = networks.Network +NetworkParams = networks.Params +PRNGKey = jnp.ndarray # A size 2 array. + + +class Agent(abc.ABC): + """Agent interface.""" + + @abc.abstractmethod + def step(self, timestep: dm_env.TimeStep) -> Action: + """Selects action given timestep and potentially learns.""" + + @abc.abstractmethod + def reset(self) -> None: + """Resets the agent's episodic state such as frame stack and action repeat. + + This method should be called at the beginning of every episode. + """ + + @abc.abstractmethod + def get_state(self) -> Mapping[Text, Any]: + """Retrieves agent state as a dictionary (e.g. for serialization).""" + + @abc.abstractmethod + def set_state(self, state: Mapping[Text, Any]) -> None: + """Sets agent state from a (potentially de-serialized) dictionary.""" + + @property + @abc.abstractmethod + def statistics(self) -> Mapping[Text, float]: + """Returns current agent statistics as a dictionary.""" + + +def run_loop( + agent: Agent, + environment: dm_env.Environment, + max_steps_per_episode: int = 0, + yield_before_reset: bool = False, +) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent, + Optional[Action]]]: + """Repeatedly alternates step calls on environment and agent. + + At time `t`, `t + 1` environment timesteps and `t + 1` agent steps have been + seen in the current episode. `t` resets to `0` for the next episode. + + Args: + agent: Agent to be run, has methods `step(timestep)` and `reset()`. + environment: Environment to run, has methods `step(action)` and `reset()`. + max_steps_per_episode: If positive, when time t reaches this value within an + episode, the episode is truncated. + yield_before_reset: Whether to additionally yield `(environment, None, + agent, None)` before the agent and environment is reset at the start of + each episode. + + Yields: + Tuple `(environment, timestep_t, agent, a_t)` where + `a_t = agent.step(timestep_t)`. + """ + while True: # For each episode. + if yield_before_reset: + yield environment, None, agent, None, + + t = 0 + agent.reset() + timestep_t = environment.reset() # timestep_0. + + while True: # For each step in the current episode. + a_t = agent.step(timestep_t) + yield environment, timestep_t, agent, a_t + + # Update t after one environment step and agent step and relabel. + t += 1 + a_tm1 = a_t + timestep_t = environment.step(a_tm1) + + if max_steps_per_episode > 0 and t >= max_steps_per_episode: + assert t == max_steps_per_episode + timestep_t = timestep_t._replace(step_type=dm_env.StepType.LAST) + + if timestep_t.last(): + unused_a_t = agent.step(timestep_t) # Extra agent step, action ignored. + yield environment, timestep_t, agent, None + break + + +def generate_statistics( + trackers: Iterable[Any], + timestep_action_sequence: Iterable[Tuple[dm_env.Environment, + Optional[dm_env.TimeStep], Agent, + Optional[Action]]] +) -> Mapping[Text, Any]: + """Generates statistics from a sequence of timestep and actions.""" + # Only reset at the start, not between episodes. + for tracker in trackers: + tracker.reset() + + for environment, timestep_t, agent, a_t in timestep_action_sequence: + for tracker in trackers: + tracker.step(environment, timestep_t, agent, a_t) + + # Merge all statistics dictionaries into one. + statistics_dicts = (tracker.get() for tracker in trackers) + return dict(collections.ChainMap(*statistics_dicts)) + + +class EpisodeTracker: + """Tracks episode return and other statistics.""" + + def __init__(self): + self._num_steps_since_reset = None + self._num_steps_over_episodes = None + self._episode_returns = None + self._current_episode_rewards = None + self._current_episode_step = None + + def step( + self, + environment: Optional[dm_env.Environment], + timestep_t: dm_env.TimeStep, + agent: Optional[Agent], + a_t: Optional[Action], + ) -> None: + """Accumulates statistics from timestep.""" + del (environment, agent, a_t) + + if timestep_t.first(): + if self._current_episode_rewards: + raise ValueError('Current episode reward list should be empty.') + if self._current_episode_step != 0: + raise ValueError('Current episode step should be zero.') + else: + # First reward is invalid, all other rewards are appended. + self._current_episode_rewards.append(timestep_t.reward) + + self._num_steps_since_reset += 1 + self._current_episode_step += 1 + + if timestep_t.last(): + self._episode_returns.append(sum(self._current_episode_rewards)) + self._current_episode_rewards = [] + self._num_steps_over_episodes += self._current_episode_step + self._current_episode_step = 0 + + def reset(self) -> None: + """Resets all gathered statistics, not to be called between episodes.""" + self._num_steps_since_reset = 0 + self._num_steps_over_episodes = 0 + self._episode_returns = [] + self._current_episode_step = 0 + self._current_episode_rewards = [] + + def get(self) -> Mapping[Text, Union[int, float, None]]: + """Aggregates statistics and returns as a dictionary. + + Here the convention is `episode_return` is set to `current_episode_return` + if a full episode has not been encountered. Otherwise it is set to + `mean_episode_return` which is the mean return of complete episodes only. If + no steps have been taken at all, `episode_return` is set to `NaN`. + + Returns: + A dictionary of aggregated statistics. + """ + if self._episode_returns: + mean_episode_return = np.array(self._episode_returns).mean() + current_episode_return = sum(self._current_episode_rewards) + episode_return = mean_episode_return + else: + mean_episode_return = np.nan + if self._num_steps_since_reset > 0: + current_episode_return = sum(self._current_episode_rewards) + else: + current_episode_return = np.nan + episode_return = current_episode_return + + return { + 'mean_episode_return': mean_episode_return, + 'current_episode_return': current_episode_return, + 'episode_return': episode_return, + 'num_episodes': len(self._episode_returns), + 'num_steps_over_episodes': self._num_steps_over_episodes, + 'current_episode_step': self._current_episode_step, + 'num_steps_since_reset': self._num_steps_since_reset, + } + + +class StepRateTracker: + """Tracks step rate, number of steps taken and duration since last reset.""" + + def __init__(self): + self._num_steps_since_reset = None + self._start = None + + def step( + self, + environment: Optional[dm_env.Environment], + timestep_t: Optional[dm_env.TimeStep], + agent: Optional[Agent], + a_t: Optional[Action], + ) -> None: + del (environment, timestep_t, agent, a_t) + self._num_steps_since_reset += 1 + + def reset(self) -> None: + self._num_steps_since_reset = 0 + self._start = timeit.default_timer() + + def get(self) -> Mapping[Text, float]: + duration = timeit.default_timer() - self._start + if self._num_steps_since_reset > 0: + step_rate = self._num_steps_since_reset / duration + else: + step_rate = np.nan + return { + 'step_rate': step_rate, + 'num_steps': self._num_steps_since_reset, + 'duration': duration, + } + + +class UnbiasedExponentialWeightedAverageAgentTracker: + """'Unbiased Constant-Step-Size Trick' from the Sutton and Barto RL book.""" + + def __init__(self, step_size: float, initial_agent: Agent): + self._initial_statistics = dict(initial_agent.statistics) + self._step_size = step_size + self.trace = 0. + self._statistics = dict(self._initial_statistics) + + def step( + self, + environment: Optional[dm_env.Environment], + timestep_t: Optional[dm_env.TimeStep], + agent: Agent, + a_t: Optional[Action], + ) -> None: + """Accumulates agent statistics.""" + del (environment, timestep_t, a_t) + + self.trace = (1 - self._step_size) * self.trace + self._step_size + final_step_size = self._step_size / self.trace + assert 0 <= final_step_size <= 1 + + if final_step_size == 1: + # Since the self._initial_statistics is likely to be NaN and + # 0 * NaN == NaN just replace self._statistics on the first step. + self._statistics = dict(agent.statistics) + else: + self._statistics = jax.tree_multimap( + lambda s, x: (1 - final_step_size) * s + final_step_size * x, + self._statistics, agent.statistics) + + def reset(self) -> None: + """Resets statistics and internal state.""" + self.trace = 0. + # get() may be called before step() so ensure statistics are initialized. + self._statistics = dict(self._initial_statistics) + + def get(self) -> Mapping[Text, float]: + """Returns current accumulated statistics.""" + return self._statistics + + +def make_default_trackers(initial_agent: Agent): + return [ + EpisodeTracker(), + StepRateTracker(), + UnbiasedExponentialWeightedAverageAgentTracker( + step_size=1e-3, initial_agent=initial_agent), + ] + + +class EpsilonGreedyActor(Agent): + """Agent that acts with a given set of Q-network parameters and epsilon. + + Network parameters are set on the actor. The actor can be serialized, + ensuring determinism of execution (e.g. when checkpointing). + """ + + def __init__( + self, + preprocessor: processors.Processor, + network: Network, + exploration_epsilon: float, + rng_key: PRNGKey, + ): + self._preprocessor = preprocessor + self._rng_key = rng_key + self._action = None + self.network_params = None # Nest of arrays (haiku.Params), set externally. + + def select_action(rng_key, network_params, s_t): + """Samples action from eps-greedy policy wrt Q-values at given state.""" + rng_key, apply_key, policy_key = jax.random.split(rng_key, 3) + q_t = network.apply(network_params, apply_key, s_t[None, ...]).q_values[0] + a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon) + return rng_key, a_t + + self._select_action = jax.jit(select_action) + + def step(self, timestep: dm_env.TimeStep) -> Action: + """Selects action given a timestep.""" + timestep = self._preprocessor(timestep) + + if timestep is None: # Repeat action. + return self._action + + s_t = timestep.observation + self._rng_key, a_t = self._select_action(self._rng_key, self.network_params, + s_t) + self._action = Action(jax.device_get(a_t)) + return self._action + + def reset(self) -> None: + """Resets the agent's episodic state such as frame stack and action repeat. + + This method should be called at the beginning of every episode. + """ + processors.reset(self._preprocessor) + self._action = None + + def get_state(self) -> Mapping[Text, Any]: + """Retrieves agent state as a dictionary (e.g. for serialization).""" + # State contains network params to make agent easy to run from a checkpoint. + return { + 'rng_key': self._rng_key, + 'network_params': self.network_params, + } + + def set_state(self, state: Mapping[Text, Any]) -> None: + """Sets agent state from a (potentially de-serialized) dictionary.""" + self._rng_key = state['rng_key'] + self.network_params = state['network_params'] + + @property + def statistics(self) -> Mapping[Text, float]: + return {} + + +class LinearSchedule: + """Linear schedule, used for exploration epsilon in DQN agents.""" + + def __init__(self, + begin_value, + end_value, + begin_t, + end_t=None, + decay_steps=None): + if (end_t is None) == (decay_steps is None): + raise ValueError('Exactly one of end_t, decay_steps must be provided.') + self._decay_steps = decay_steps if end_t is None else end_t - begin_t + self._begin_t = begin_t + self._begin_value = begin_value + self._end_value = end_value + + def __call__(self, t): + """Implements a linear transition from a begin to an end value.""" + frac = min(max(t - self._begin_t, 0), self._decay_steps) / self._decay_steps + return (1 - frac) * self._begin_value + frac * self._end_value + + +class NullWriter: + """A placeholder logging object that does nothing.""" + + def write(self, *args, **kwargs) -> None: + pass + + def close(self) -> None: + pass + + +class CsvWriter: + """A logging object writing to a CSV file. + + Each `write()` takes a `OrderedDict`, creating one column in the CSV file for + each dictionary key on the first call. Successive calls to `write()` must + contain the same dictionary keys. + """ + + def __init__(self, fname: Text): + """Initializes a `CsvWriter`. + + Args: + fname: File name (path) for file to be written to. + """ + dirname = os.path.dirname(fname) + if not os.path.exists(dirname): + os.makedirs(dirname) + self._fname = fname + self._header_written = False + self._fieldnames = None + + def write(self, values: collections.OrderedDict) -> None: + """Appends given values as new row to CSV file.""" + if self._fieldnames is None: + self._fieldnames = values.keys() + # Open a file in 'append' mode, so we can continue logging safely to the + # same file after e.g. restarting from a checkpoint. + with open(self._fname, 'a') as file: + # Always use same fieldnames to create writer, this way a consistency + # check is performed automatically on each write. + writer = csv.DictWriter(file, fieldnames=self._fieldnames) + # Write a header if this is the very first write. + if not self._header_written: + writer.writeheader() + self._header_written = True + writer.writerow(values) + + def close(self) -> None: + """Closes the `CsvWriter`.""" + pass + + def get_state(self) -> Mapping[Text, Any]: + """Retrieves `CsvWriter` state as a `dict` (e.g. for serialization).""" + return { + 'header_written': self._header_written, + 'fieldnames': self._fieldnames + } + + def set_state(self, state: Mapping[Text, Any]) -> None: + """Sets `CsvWriter` state from a (potentially de-serialized) dictionary.""" + self._header_written = state['header_written'] + self._fieldnames = state['fieldnames'] + + +class NullCheckpoint: + """A placeholder checkpointing object that does nothing. + + Can be used as a substitute for an actual checkpointing object when + checkpointing is disabled. + """ + + def __init__(self): + self.state = AttributeDict() + + def save(self) -> None: + pass + + def can_be_restored(self) -> bool: + return False + + def restore(self) -> None: + pass + + +class AttributeDict(dict): + """A `dict` that supports getting, setting, deleting keys via attributes.""" + + def __getattr__(self, key): + return self[key] + + def __setattr__(self, key, value): + self[key] = value + + def __delattr__(self, key): + del self[key] diff --git a/tandem_dqn/processors.py b/tandem_dqn/processors.py new file mode 100644 index 0000000..56f5b32 --- /dev/null +++ b/tandem_dqn/processors.py @@ -0,0 +1,604 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Composable timestep processing, for DQN Atari preprocessing. + +Aims: +* Be self-contained. +* Easy to have the preprocessing on the agent side or on the environment side. +* Easy to swap out and modify parts of the processing. + +Conventions: +* The term "processor" is used to refer to any callable that could also have + a `reset()` function to clear any internal state. E.g. a plain function. Or an + instance of a class with `__call__` method, with or without a `reset()` + method. +* `None` means no output when subsampling inputs. +""" + +import collections +from typing import Any, Callable, List, Iterable, Optional, Sequence, Text, Tuple + +import dm_env +from dm_env import specs +import numpy as np +from PIL import Image + +Processor = Callable # Actually a callable that may also have a reset() method. +Nest = Any # Recursive types are not yet supported by pytype. +NamedTuple = Any +StepType = dm_env.StepType + + +def reset(processor: Processor[[Any], Any]) -> None: + """Calls `reset()` on a `Processor` or function if the method exists.""" + if hasattr(processor, 'reset'): + processor.reset() + + +identity = lambda v: v + + +def trailing_zero_pad( + length: int) -> Processor[[List[np.ndarray]], List[np.ndarray]]: + """Adds trailing zero padding to array lists to ensure a minimum length.""" + + def trailing_zero_pad_fn(arrays): + padding_length = length - len(arrays) + if padding_length <= 0: + return arrays + zero = np.zeros_like(arrays[0]) + return arrays + [zero] * padding_length + + return trailing_zero_pad_fn + + +def none_to_zero_pad(values: List[Optional[NamedTuple]]) -> List[NamedTuple]: + """Replaces `None`s in a list of named tuples with zeros of same structure.""" + + actual_values = [n for n in values if n is not None] + if not actual_values: + raise ValueError('Must have at least one value which is not None.') + if len(actual_values) == len(values): + return values + example = actual_values[0] + zero = type(example)(*(np.zeros_like(x) for x in example)) + return [zero if v is None else v for v in values] + + +def named_tuple_sequence_stack(values: Sequence[NamedTuple]) -> NamedTuple: + """Converts a sequence of named tuples into a named tuple of tuples.""" + # [T(1, 2), T(3, 4), T(5, 6)]. + transposed = zip(*values) + # ((1, 3, 5), (2, 4, 6)). + return type(values[0])(*transposed) + # T((1, 3, 5), (2, 4, 6)). + + +class Deque: + """Double ended queue with a maximum length and initial values.""" + + def __init__(self, max_length: int, initial_values=None): + self._deque = collections.deque(maxlen=max_length) + self._initial_values = initial_values or [] + + def reset(self) -> None: + self._deque.clear() + self._deque.extend(self._initial_values) + + def __call__(self, value: Any) -> collections.deque: + self._deque.append(value) + return self._deque + + +class FixedPaddedBuffer: + """Fixed size `None`-padded buffer which is cleared after it is filled. + + E.g. with `length = 3`, `initial_index = 2` and values `[0, 1, 2, 3, 4, 5, 6]` + this will return `~~0`, `1~~`, `12~`, `123`, `4~~`, `45~`, `456`, where `~` + represents `None`. Used to concatenate timesteps for action repeats. + + Action repeat requirements are: + * Fixed size buffer of timesteps. + * The `FIRST` timestep should return immediately to get the first action of + the episode, as there is no preceding action to repeat. Prefix with padding. + * For `MID` timesteps, the timestep buffer is periodically returned when full. + * When a `LAST` timestep is encountered, the current buffer of timesteps is + returned, suffixed with padding, as buffers should not cross episode + boundaries. + + The requirements can be fulfilled by conditionally subsampling the output of + this processor. + """ + + def __init__(self, length: int, initial_index: int): + self._length = length + self._initial_index = initial_index % length + + self._index = self._initial_index + self._buffer = [None] * self._length + + def reset(self) -> None: + self._index = self._initial_index + self._buffer = [None] * self._length + + def __call__(self, value: Any) -> Sequence[Any]: + if self._index >= self._length: + assert self._index == self._length + self._index = 0 + self._buffer = [None] * self._length + self._buffer[self._index] = value + self._index += 1 + return self._buffer + + +class ConditionallySubsample: + """Conditionally passes through input, returning `None` otherwise.""" + + def __init__(self, condition: Processor[[Any], bool]): + self._condition = condition + + def reset(self) -> None: + reset(self._condition) + + def __call__(self, value: Any) -> Optional[Any]: + return value if self._condition(value) else None + + +class TimestepBufferCondition: + """Returns `True` when an iterable of timesteps should be passed on. + + Specifically returns `True`: + * If timesteps contain a `FIRST`. + * If timesteps contain a `LAST`. + * If number of steps passed since `FIRST` timestep modulo `period` is `0`. + + Returns `False` otherwise. Used for action repeats in Atari preprocessing. + """ + + def __init__(self, period: int): + self._period = period + self._steps_since_first_timestep = None + self._should_reset = False + + def reset(self): + self._should_reset = False + self._steps_since_first_timestep = None + + def __call__(self, timesteps: Iterable[dm_env.TimeStep]) -> bool: + if self._should_reset: + raise RuntimeError('Should have reset.') + + # Find the main step type, FIRST and LAST take precedence over MID. + main_step_type = StepType.MID + precedent_step_types = (StepType.FIRST, StepType.LAST) + for timestep in timesteps: + if timestep is None: + continue + if timestep.step_type in precedent_step_types: + if main_step_type in precedent_step_types: + raise RuntimeError('Expected at most one FIRST or LAST.') + main_step_type = timestep.step_type + + # Must have FIRST timestep after a reset. + if self._steps_since_first_timestep is None: + if main_step_type != StepType.FIRST: + raise RuntimeError('After reset first timestep should be FIRST.') + + # pytype: disable=unsupported-operands + if main_step_type == StepType.FIRST: + self._steps_since_first_timestep = 0 + return True + elif main_step_type == StepType.LAST: + self._steps_since_first_timestep = None + self._should_reset = True + return True + elif (self._steps_since_first_timestep + 1) % self._period == 0: + self._steps_since_first_timestep += 1 + return True + else: + self._steps_since_first_timestep += 1 + return False + # pytype: enable=unsupported-operands + + +class ApplyToNamedTupleField: + """Runs processors on a particular field of a named tuple.""" + + def __init__(self, field: Text, *processors: Processor[[Any], Any]): + self._field = field + self._processors = processors + + def reset(self) -> None: + for processor in self._processors: + reset(processor) + + def __call__(self, value: NamedTuple) -> NamedTuple: + attr_value = getattr(value, self._field) + for processor in self._processors: + attr_value = processor(attr_value) + + return value._replace(**{self._field: attr_value}) + + +class Maybe: + """Wraps another processor so that `None` is returned when `None` is input.""" + + def __init__(self, processor: Processor[[Any], Any]): + self._processor = processor + + def reset(self) -> None: + reset(self._processor) + + def __call__(self, value: Optional[Any]) -> Optional[Any]: + if value is None: + return None + else: + return self._processor(value) + + +class Sequential: + """Chains together multiple processors.""" + + def __init__(self, *processors: Processor[[Any], Any]): + self._processors = processors + + def reset(self) -> None: + for processor in self._processors: + reset(processor) + + def __call__(self, value: Any) -> Any: + for processor in self._processors: + value = processor(value) + return value + + +class ZeroDiscountOnLifeLoss: + """Sets discount to zero on timestep if number of lives has decreased. + + This processor assumes observations to be tuples whose second entry is a + scalar indicating the remaining number of lives. + """ + + def __init__(self): + self._num_lives_on_prev_step = None + + def reset(self) -> None: + self._num_lives_on_prev_step = None + + def __call__(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: + # We have a life loss when the timestep is a regular transition and lives + # have decreased since the previous timestep. + num_lives = timestep.observation[1] + life_lost = timestep.mid() and (num_lives < self._num_lives_on_prev_step) + self._num_lives_on_prev_step = num_lives + return timestep._replace(discount=0.) if life_lost else timestep + + +def reduce_step_type(step_types: Sequence[StepType], + debug: bool = False) -> StepType: + """Outputs a representative step type from an array of step types.""" + # Zero padding will appear to be FIRST. Padding should only be seen before the + # FIRST (e.g. 000F) or after LAST (e.g. ML00). + if debug: + np_step_types = np.array(step_types) + output_step_type = StepType.MID + for i, step_type in enumerate(step_types): + if step_type == 0: # step_type not actually FIRST, but we do expect 000F. + if debug and not (np_step_types == 0).all(): + raise ValueError('Expected zero padding followed by FIRST.') + output_step_type = StepType.FIRST + break + elif step_type == StepType.LAST: + output_step_type = StepType.LAST + if debug and not (np_step_types[i + 1:] == 0).all(): + raise ValueError('Expected LAST to be followed by zero padding.') + break + else: + if step_type != StepType.MID: + raise ValueError('Expected MID if not FIRST or LAST.') + return output_step_type + + +def aggregate_rewards(rewards: Sequence[Optional[float]], + debug: bool = False) -> Optional[float]: + """Sums up rewards, assumes discount is 1.""" + if None in rewards: + if debug: + np_rewards = np.array(rewards) + if not (np_rewards[-1] is None and (np_rewards[:-1] == 0).all()): + # Should only ever have [0, 0, 0, None] due to zero padding. + raise ValueError('Should only have a None reward for FIRST.') + return None + else: + # Faster than np.sum for a list of floats. + return sum(rewards) + + +def aggregate_discounts(discounts: Sequence[Optional[float]], + debug: bool = False) -> Optional[float]: + """Aggregates array of discounts into a scalar, expects `0`, `1` or `None`.""" + if debug: + np_discounts = np.array(discounts) + if not np.isin(np_discounts, [0., 1., None]).all(): + raise ValueError('All discounts should be 0 or 1, got: %s.' % + np_discounts) + if None in discounts: + if debug: + if not (np_discounts[-1] is None and (np_discounts[:-1] == 0).all()): + # Should have [0, 0, 0, None] due to zero padding. + raise ValueError('Should only have a None discount for FIRST.') + return None + else: + # Faster than np.prod for a list of floats. + result = 1 + for d in discounts: + result *= d + return result + + +def rgb2y(array: np.ndarray) -> np.ndarray: + """Converts RGB image array into grayscale.""" + if array.ndim != 3: + raise ValueError('Input array should be 3D, got %s.' % array.ndim) + output = np.tensordot(array, [0.299, 0.587, 1 - (0.299 + 0.587)], (-1, 0)) + return output.astype(np.uint8) + + +def resize(shape: Tuple[int, ...]) -> Processor[[np.ndarray], np.ndarray]: + """Resizes array to the given shape.""" + if len(shape) != 2: + raise ValueError('Resize shape has to be 2D, given: %s.' % str(shape)) + # Image.resize takes (width, height) as output_shape argument. + image_shape = (shape[1], shape[0]) + + def resize_fn(array): + image = Image.fromarray(array).resize(image_shape, Image.BILINEAR) + return np.array(image, dtype=np.uint8) + + return resize_fn + + +def select_rgb_observation(timestep: dm_env.TimeStep) -> dm_env.TimeStep: + """Replaces an observation tuple by its first entry (the RGB observation).""" + return timestep._replace(observation=timestep.observation[0]) + + +def apply_additional_discount( + additional_discount: float) -> Processor[[float], float]: + """Returns a function that scales its non-`None` input by a constant.""" + return lambda d: None if d is None else additional_discount * d + + +def clip_reward(bound: float) -> Processor[[Optional[float]], Optional[float]]: + """Returns a function that clips non-`None` inputs to (`-bound`, `bound`).""" + + def clip_reward_fn(reward): + return None if reward is None else max(min(reward, bound), -bound) + + return clip_reward_fn + + +def show(prefix: Text) -> Processor[[Any], Any]: + """Prints value and passes through, for debugging.""" + + def show_fn(value): + print('%s: %s' % (prefix, value)) + return value + + return show_fn + + +def atari( + additional_discount: float = 0.99, + max_abs_reward: Optional[float] = 1.0, + resize_shape: Optional[Tuple[int, int]] = (84, 84), + num_action_repeats: int = 4, + num_pooled_frames: int = 2, + zero_discount_on_life_loss: bool = True, + num_stacked_frames: int = 4, + grayscaling: bool = True, +) -> Processor[[dm_env.TimeStep], Optional[dm_env.TimeStep]]: + """Standard DQN preprocessing on Atari.""" + + # This processor does the following to a sequence of timesteps. + # + # 1. Zeroes discount on loss of life. + # 2. Repeats actions (previous action should be repeated if None is returned). + # 3. Max pools action repeated observations. + # 4. Grayscales observations. + # 5. Resizes observations. + # 6. Stacks observations. + # 7. Clips rewards. + # 8. Applies an additional discount. + # + # For more detail see the annotations in the processors below. + + # The FixedPaddedBuffer, ConditionallySubsample, none_to_zero_pad, stack and + # max_pool on the observation collectively does this (step types: F = FIRST, + # M = MID, L = LAST, ~ is None): + # + # Type: F | M M M M | M M L | F | + # Frames: A | B C D E | F G H | I | + # Output: max[0A]| ~ ~ ~ max[DE]| ~ ~ max[H0]|max[0I]| + return Sequential( + # When the number of lives decreases, set discount to 0. + ZeroDiscountOnLifeLoss() if zero_discount_on_life_loss else identity, + # Select the RGB observation as the main observation, dropping lives. + select_rgb_observation, + # obs: 1, 2, 3, 4, 5, 6, 7, 8, 9, ... + # Write timesteps into a fixed-sized buffer padded with None. + FixedPaddedBuffer(length=num_action_repeats, initial_index=-1), + # obs: ~~~1, 2~~~, 23~~, 234~, 2345, 6~~~, 67~~, 678~, 6789, ... + # Periodically return the deque of timesteps, when the current timestep is + # FIRST, after that every 4 steps, and when the current timestep is LAST. + ConditionallySubsample(TimestepBufferCondition(num_action_repeats)), + # obs: ~~~1, ~, ~, ~, 2345, ~, ~, ~, 6789, ... + # If None pass through, otherwise apply the processor. + Maybe( + Sequential( + # Replace Nones with zero padding in each buffer. + none_to_zero_pad, + # obs: 0001, ~, ~, ~, 2345, ~, ~, ~, 6789, ... + # Convert sequence of nests into a nest of sequences. + named_tuple_sequence_stack, + # Choose representative step type from an array of step types. + ApplyToNamedTupleField('step_type', reduce_step_type), + # Rewards: sum then clip. + ApplyToNamedTupleField( + 'reward', + aggregate_rewards, + clip_reward(max_abs_reward) if max_abs_reward else identity, + ), + # Discounts: take product and scale by an additional discount. + ApplyToNamedTupleField( + 'discount', + aggregate_discounts, + apply_additional_discount(additional_discount), + ), + # Observations: max pool, grayscale, resize, and stack. + ApplyToNamedTupleField( + 'observation', + lambda obs: np.stack(obs[-num_pooled_frames:], axis=0), + lambda obs: np.max(obs, axis=0), + # obs: max[01], ~, ~, ~, max[45], ~, ~, ~, max[89], ... + # obs: A, ~, ~, ~, B, ~, ~, ~, C, ... + rgb2y if grayscaling else identity, + resize(resize_shape) if resize_shape else identity, + Deque(max_length=num_stacked_frames), + # obs: A, ~, ~, ~, AB, ~, ~, ~, ABC, ~, ~, ~, ABCD, ~, ~, ~, + # BCDE, ~, ~, ~, CDEF, ... + list, + trailing_zero_pad(length=num_stacked_frames), + # obs: A000, ~, ~, ~, AB00, ~, ~, ~, ABC0, ~, ~, ~, ABCD, + # ~, ~, ~, BCDE, ... + lambda obs: np.stack(obs, axis=-1), + ), + )), + ) + + +class AtariEnvironmentWrapper(dm_env.Environment): + """Python environment wrapper that provides DQN Atari preprocessing. + + This is a thin wrapper around the Atari processor. + + Expects underlying Atari environment to have interleaved pixels (HWC) and + zero-indexed actions. + """ + + def __init__( + self, + environment: dm_env.Environment, + additional_discount: float = 0.99, + max_abs_reward: Optional[float] = 1.0, + resize_shape: Optional[Tuple[int, int]] = (84, 84), + num_action_repeats: int = 4, + num_pooled_frames: int = 2, + zero_discount_on_life_loss: bool = True, + num_stacked_frames: int = 4, + grayscaling: bool = True, + ): + rgb_spec, unused_lives_spec = environment.observation_spec() + if rgb_spec.shape[2] != 3: + raise ValueError( + 'This wrapper assumes interleaved pixel observations with shape ' + '(height, width, channels).') + if int(environment.action_spec().minimum) != 0: + raise ValueError('This wrapper assumes zero-indexed actions.') + + self._environment = environment + self._processor = atari( + additional_discount=additional_discount, + max_abs_reward=max_abs_reward, + resize_shape=resize_shape, + num_action_repeats=num_action_repeats, + num_pooled_frames=num_pooled_frames, + zero_discount_on_life_loss=zero_discount_on_life_loss, + num_stacked_frames=num_stacked_frames, + grayscaling=grayscaling, + ) + + if grayscaling: + self._observation_shape = resize_shape + (num_stacked_frames,) + self._observation_spec_name = 'grayscale' + else: + self._observation_shape = resize_shape + (3, num_stacked_frames) + self._observation_spec_name = 'RGB' + + self._reset_next_step = True + + def reset(self) -> dm_env.TimeStep: + """Resets environment and provides the first processed timestep.""" + reset(self._processor) + timestep = self._environment.reset() + processed_timestep = self._processor(timestep) + assert processed_timestep is not None + self._reset_next_step = False + return processed_timestep + + def step(self, action: int) -> dm_env.TimeStep: + """Steps up to `num_action_repeat` times, returns a processed timestep.""" + # This implements the action repeat by repeatedly passing in the last action + # until an actual timestep is returned by the processor. + + if self._reset_next_step: + return self.reset() # Ignore action. + + processed_timestep = None + while processed_timestep is None: + timestep = self._environment.step(action) + processed_timestep = self._processor(timestep) + if timestep.last(): + self._reset_next_step = True + assert processed_timestep is not None + return processed_timestep + + def action_spec(self) -> specs.DiscreteArray: + return self._environment.action_spec() + + def observation_spec(self) -> specs.Array: + return specs.Array( + shape=self._observation_shape, + dtype=np.uint8, + name=self._observation_spec_name) + + +class AtariSimpleActionEnvironmentWrapper(dm_env.Environment): + """Python environment wrapper for Atari so it takes integer actions. + + Use this when processing is done on the agent side. + """ + + def __init__(self, environment: dm_env.Environment): + self._environment = environment + if int(environment.action_spec()[0].minimum) != 0: + raise ValueError( + 'This wrapper assumes zero-indexed actions. Use the Atari setting ' + 'zero_indexed_actions=\"true\" to get actions in this format.') + + def reset(self) -> dm_env.TimeStep: + return self._environment.reset() + + def step(self, action: int) -> dm_env.TimeStep: + return self._environment.step([np.array(action).reshape((1,))]) + + def action_spec(self) -> specs.DiscreteArray: + action_spec = self._environment.action_spec()[0] + return specs.DiscreteArray( + num_values=action_spec.maximum.item() + 1, + dtype=action_spec.dtype, + name='action_spec') + + def observation_spec(self) -> specs.Array: + return self._environment.observation_spec() diff --git a/tandem_dqn/replay.py b/tandem_dqn/replay.py new file mode 100644 index 0000000..e919690 --- /dev/null +++ b/tandem_dqn/replay.py @@ -0,0 +1,167 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Replay components for DQN-type agents.""" + +import collections +import typing +from typing import Any, Callable, Generic, Iterable, List, Mapping, Optional, Sequence, Text, Tuple, TypeVar + +import dm_env +import numpy as np +import snappy + +from tandem_dqn import parts + +CompressedArray = Tuple[bytes, Tuple, np.dtype] + +# Generic replay structure: Any flat named tuple. +ReplayStructure = TypeVar('ReplayStructure', bound=Tuple[Any, ...]) + + +class Transition(typing.NamedTuple): + s_tm1: Optional[np.ndarray] + a_tm1: Optional[parts.Action] + r_t: Optional[float] + discount_t: Optional[float] + s_t: Optional[np.ndarray] + a_t: Optional[parts.Action] = None + mc_return_tm1: Optional[float] = None + + +class TransitionReplay(Generic[ReplayStructure]): + """Uniform replay, with circular buffer storage for flat named tuples.""" + + def __init__(self, + capacity: int, + structure: ReplayStructure, + random_state: np.random.RandomState, + encoder: Optional[Callable[[ReplayStructure], Any]] = None, + decoder: Optional[Callable[[Any], ReplayStructure]] = None): + self._capacity = capacity + self._structure = structure + self._random_state = random_state + self._encoder = encoder or (lambda s: s) + self._decoder = decoder or (lambda s: s) + + self._storage = [None] * capacity + self._num_added = 0 + + def add(self, item: ReplayStructure) -> None: + """Adds single item to replay.""" + self._storage[self._num_added % self._capacity] = self._encoder(item) + self._num_added += 1 + + def get(self, indices: Sequence[int]) -> List[ReplayStructure]: + """Retrieves items by indices.""" + return [self._decoder(self._storage[i]) for i in indices] + + def sample(self, size: int) -> ReplayStructure: + """Samples batch of items from replay uniformly, with replacement.""" + indices = self._random_state.choice(self.size, size=size, replace=True) + samples = self.get(indices) + transposed = zip(*samples) + stacked = [np.stack(xs, axis=0) for xs in transposed] + return type(self._structure)(*stacked) # pytype: disable=not-callable + + @property + def size(self) -> int: + """Number of items currently contained in replay.""" + return min(self._num_added, self._capacity) + + @property + def capacity(self) -> int: + """Total capacity of replay (max number of items stored at any one time).""" + return self._capacity + + def get_state(self) -> Mapping[Text, Any]: + """Retrieves replay state as a dictionary (e.g. for serialization).""" + return { + 'storage': self._storage, + 'num_added': self._num_added, + } + + def set_state(self, state: Mapping[Text, Any]) -> None: + """Sets replay state from a (potentially de-serialized) dictionary.""" + self._storage = state['storage'] + self._num_added = state['num_added'] + + +class TransitionAccumulatorWithMCReturn: + """Accumulates timesteps to transitions with MC returns.""" + + def __init__(self): + self._transitions = collections.deque() + self.reset() + + def step(self, timestep_t: dm_env.TimeStep, + a_t: parts.Action) -> Iterable[Transition]: + """Accumulates timestep and resulting action, maybe yields transitions.""" + if timestep_t.first(): + self.reset() + + # There are no transitions on the first timestep. + if self._timestep_tm1 is None: + assert self._a_tm1 is None + if not timestep_t.first(): + raise ValueError('Expected FIRST timestep, got %s.' % str(timestep_t)) + self._timestep_tm1 = timestep_t + self._a_tm1 = a_t + return # Empty iterable. + + self._transitions.append( + Transition( + s_tm1=self._timestep_tm1.observation, + a_tm1=self._a_tm1, + r_t=timestep_t.reward, + discount_t=timestep_t.discount, + s_t=timestep_t.observation, + a_t=a_t, + mc_return_tm1=None, + )) + + self._timestep_tm1 = timestep_t + self._a_tm1 = a_t + + if timestep_t.last(): + # Annotate all episode transitions with their MC returns. + mc_return = 0 + mc_transitions = [] + while self._transitions: + transition = self._transitions.pop() + mc_return = transition.discount_t * mc_return + transition.r_t + mc_transitions.append(transition._replace(mc_return_tm1=mc_return)) + for transition in reversed(mc_transitions): + yield transition + + else: + # Wait for episode end before yielding anything. + return + + def reset(self) -> None: + """Resets the accumulator. Following timestep is expected to be FIRST.""" + self._transitions.clear() + self._timestep_tm1 = None + self._a_tm1 = None + + +def compress_array(array: np.ndarray) -> CompressedArray: + """Compresses a numpy array with snappy.""" + return snappy.compress(array), array.shape, array.dtype + + +def uncompress_array(compressed: CompressedArray) -> np.ndarray: + """Uncompresses a numpy array with snappy given its shape and dtype.""" + compressed_array, shape, dtype = compressed + byte_string = snappy.uncompress(compressed_array) + return np.frombuffer(byte_string, dtype=dtype).reshape(shape) diff --git a/tandem_dqn/requirements.txt b/tandem_dqn/requirements.txt new file mode 100644 index 0000000..528acd3 --- /dev/null +++ b/tandem_dqn/requirements.txt @@ -0,0 +1,38 @@ +# DeepMind libraries. +chex==0.0.2 +dm-env==1.2 +dm-haiku @ git+git://github.com/deepmind/dm-haiku.git@db991d56563221d5a06be5e7228155e53d01aba9 +dm-tree==0.1.5 +optax==0.0.1 +rlax @ git+git://github.com/deepmind/rlax.git@870cba1ea8ad36725f4f3a790846298657b6fd4b + +# JAX libraries with GPU support. +jax==0.1.72 +jaxlib==0.1.50 + +# ALE and Gym libraries. +atari-py==0.2.6 +gym==0.13.1 + +# Base dependencies. +absl-py==0.9.0 +numpy==1.18.0 +Pillow==7.1.2 +python-snappy==0.6 + +# Transitive dependencies. +asn1crypto==0.24.0 +cloudpickle==1.2.2 +cryptography==2.1.4 +dataclasses==0.7 +future==0.18.2 +idna==2.6 +keyring==10.6.0 +keyrings.alt==3.0 +opencv-python==4.2.0.34 +opt-einsum==3.2.1 +pycrypto==2.6.1 +pyxdg==0.25 +SecretStorage==2.3.1 +six==1.15.0 +toolz==0.11.1 diff --git a/tandem_dqn/run.sh b/tandem_dqn/run.sh new file mode 100644 index 0000000..b5d8abb --- /dev/null +++ b/tandem_dqn/run.sh @@ -0,0 +1,46 @@ +#!/bin/bash +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script installs tandem_dqn in a clean virtualenv and runs an +# example training loop. It is designed to be run from the parent directory, +# e.g.: +# +# git clone git@github.com:deepmind/deepmind-research.git +# cd deepmind_research +# tandem_dqn/run.sh + +# Fail on errors, display commands. +set -e +set -x + +TMP_DIR=`mktemp -d` + +# Set up environment. +virtualenv --python=python3.6 "${TMP_DIR}/env" +source "${TMP_DIR}/env/bin/activate" + +# Install deps. +pip install --upgrade -r tandem_dqn/requirements.txt + +# Run small-scale example for a few thousand steps. +python -m tandem_dqn.run_tandem \ + --environment_name=space_invaders \ + --replay_capacity=10000 --target_network_update_period=40 \ + --num_iterations=2 --num_train_frames=5000 --num_eval_frames=500 \ + --jax_platform_name=cpu + +# Clean up. +rm -r ${TMP_DIR} +echo "Test run finished." diff --git a/tandem_dqn/run_tandem.py b/tandem_dqn/run_tandem.py new file mode 100644 index 0000000..9112f8f --- /dev/null +++ b/tandem_dqn/run_tandem.py @@ -0,0 +1,356 @@ +# Copyright 2021 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A Tandem DQN agent implemented in JAX, training on Atari.""" + +import collections +import itertools +import sys +import typing + +from absl import app +from absl import flags +from absl import logging +import dm_env +import jax +from jax.config import config +import numpy as np +import optax + +from tandem_dqn import agent as agent_lib +from tandem_dqn import atari_data +from tandem_dqn import gym_atari +from tandem_dqn import losses +from tandem_dqn import networks +from tandem_dqn import parts +from tandem_dqn import processors +from tandem_dqn import replay as replay_lib + +# Relevant flag values are expressed in terms of environment frames. +FLAGS = flags.FLAGS +flags.DEFINE_string('environment_name', 'pong', '') +flags.DEFINE_boolean('use_sticky_actions', False, '') +flags.DEFINE_integer('environment_height', 84, '') +flags.DEFINE_integer('environment_width', 84, '') +flags.DEFINE_integer('replay_capacity', int(1e6), '') +flags.DEFINE_bool('compress_state', True, '') +flags.DEFINE_float('min_replay_capacity_fraction', 0.05, '') +flags.DEFINE_integer('batch_size', 32, '') +flags.DEFINE_integer('max_frames_per_episode', 108000, '') # 30 mins. +flags.DEFINE_integer('num_action_repeats', 4, '') +flags.DEFINE_integer('num_stacked_frames', 4, '') +flags.DEFINE_float('exploration_epsilon_begin_value', 1., '') +flags.DEFINE_float('exploration_epsilon_end_value', 0.01, '') +flags.DEFINE_float('exploration_epsilon_decay_frame_fraction', 0.02, '') +flags.DEFINE_float('eval_exploration_epsilon', 0.01, '') +flags.DEFINE_integer('target_network_update_period', int(1.2e5), '') +flags.DEFINE_float('additional_discount', 0.99, '') +flags.DEFINE_float('max_abs_reward', 1., '') +flags.DEFINE_integer('seed', 1, '') # GPU may introduce nondeterminism. +flags.DEFINE_integer('num_iterations', 200, '') +flags.DEFINE_integer('num_train_frames', int(1e6), '') # Per iteration. +flags.DEFINE_integer('num_eval_frames', int(5e5), '') # Per iteration. +flags.DEFINE_integer('learn_period', 16, '') +flags.DEFINE_string('results_csv_path', '/tmp/results.csv', '') + +# Tandem-specific parameters. + +# Using fixed configs for optimizers: +# RMSProp: lr = 0.00025, eps=0.01 / (32 ** 2) +# ADAM: lr = 0.00005, eps=0.01 / 32 +_OPTIMIZERS = ['rmsprop', 'adam'] +flags.DEFINE_enum('optimizer_active', 'rmsprop', _OPTIMIZERS, '') +flags.DEFINE_enum('optimizer_passive', 'rmsprop', _OPTIMIZERS, '') + +_NETWORKS = ['double_q', 'qr'] +flags.DEFINE_enum('network_active', 'double_q', _NETWORKS, '') +flags.DEFINE_enum('network_passive', 'double_q', _NETWORKS, '') + +_LOSSES = ['double_q', 'double_q_v', 'double_q_p', 'double_q_pv', 'qr', + 'q_regression'] +flags.DEFINE_enum('loss_active', 'double_q', _LOSSES, '') +flags.DEFINE_enum('loss_passive', 'double_q', _LOSSES, '') + +flags.DEFINE_integer('tied_layers', 0, '') + +TandemTuple = agent_lib.TandemTuple + + +def make_optimizer(optimizer_type): + """Constructs optimizer.""" + if optimizer_type == 'rmsprop': + learning_rate = 0.00025 + epsilon = 0.01 / (32**2) + optimizer = optax.rmsprop( + learning_rate=learning_rate, + decay=0.95, + eps=epsilon, + centered=True) + elif optimizer_type == 'adam': + learning_rate = 0.00005 + epsilon = 0.01 / 32 + optimizer = optax.adam( + learning_rate=learning_rate, + eps=epsilon) + else: + raise ValueError('Unknown optimizer "{}"'.format(optimizer_type)) + return optimizer + + +def main(argv): + """Trains Tandem DQN agent on Atari.""" + del argv + logging.info('Tandem DQN on Atari on %s.', + jax.lib.xla_bridge.get_backend().platform) + random_state = np.random.RandomState(FLAGS.seed) + rng_key = jax.random.PRNGKey( + random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64)) + + if FLAGS.results_csv_path: + writer = parts.CsvWriter(FLAGS.results_csv_path) + else: + writer = parts.NullWriter() + + def environment_builder(): + """Creates Atari environment.""" + env = gym_atari.GymAtari( + FLAGS.environment_name, + sticky_actions=FLAGS.use_sticky_actions, + seed=random_state.randint(1, 2**32)) + return gym_atari.RandomNoopsEnvironmentWrapper( + env, + min_noop_steps=1, + max_noop_steps=30, + seed=random_state.randint(1, 2**32), + ) + + env = environment_builder() + + logging.info('Environment: %s', FLAGS.environment_name) + logging.info('Action spec: %s', env.action_spec()) + logging.info('Observation spec: %s', env.observation_spec()) + num_actions = env.action_spec().num_values + + # Check: qr network and qr losses can only be used together. + if ('qr' in FLAGS.network_active) != ('qr' in FLAGS.loss_active): + raise ValueError('Active loss/net must either both use QR, or neither.') + if ('qr' in FLAGS.network_passive) != ('qr' in FLAGS.loss_passive): + raise ValueError('Passive loss/net must either both use QR, or neither.') + network = TandemTuple( + active=networks.make_network(FLAGS.network_active, num_actions), + passive=networks.make_network(FLAGS.network_passive, num_actions), + ) + loss = TandemTuple( + active=losses.make_loss_fn(FLAGS.loss_active, active=True), + passive=losses.make_loss_fn(FLAGS.loss_passive, active=False), + ) + + # Tied layers. + assert 0 <= FLAGS.tied_layers <= 4 + if FLAGS.tied_layers > 0 and (FLAGS.network_passive != 'double_q' + or FLAGS.network_active != 'double_q'): + raise ValueError('Tied layers > 0 is only supported for double_q networks.') + layers = [ + 'sequential/sequential/conv1', + 'sequential/sequential/conv2', + 'sequential/sequential/conv3', + 'sequential/sequential_1/linear1' + ] + tied_layers = set(layers[:FLAGS.tied_layers]) + + def preprocessor_builder(): + return processors.atari( + additional_discount=FLAGS.additional_discount, + max_abs_reward=FLAGS.max_abs_reward, + resize_shape=(FLAGS.environment_height, FLAGS.environment_width), + num_action_repeats=FLAGS.num_action_repeats, + num_pooled_frames=2, + zero_discount_on_life_loss=True, + num_stacked_frames=FLAGS.num_stacked_frames, + grayscaling=True, + ) + + # Create sample network input from sample preprocessor output. + sample_processed_timestep = preprocessor_builder()(env.reset()) + sample_processed_timestep = typing.cast(dm_env.TimeStep, + sample_processed_timestep) + sample_network_input = sample_processed_timestep.observation + assert sample_network_input.shape == (FLAGS.environment_height, + FLAGS.environment_width, + FLAGS.num_stacked_frames) + + exploration_epsilon_schedule = parts.LinearSchedule( + begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity * + FLAGS.num_action_repeats), + decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction * + FLAGS.num_iterations * FLAGS.num_train_frames), + begin_value=FLAGS.exploration_epsilon_begin_value, + end_value=FLAGS.exploration_epsilon_end_value) + + if FLAGS.compress_state: + + def encoder(transition): + return transition._replace( + s_tm1=replay_lib.compress_array(transition.s_tm1), + s_t=replay_lib.compress_array(transition.s_t)) + + def decoder(transition): + return transition._replace( + s_tm1=replay_lib.uncompress_array(transition.s_tm1), + s_t=replay_lib.uncompress_array(transition.s_t)) + else: + encoder = None + decoder = None + + replay_structure = replay_lib.Transition( + s_tm1=None, + a_tm1=None, + r_t=None, + discount_t=None, + s_t=None, + a_t=None, + mc_return_tm1=None, + ) + + replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure, + random_state, encoder, decoder) + + optimizer = TandemTuple( + active=make_optimizer(FLAGS.optimizer_active), + passive=make_optimizer(FLAGS.optimizer_passive), + ) + + train_rng_key, eval_rng_key = jax.random.split(rng_key) + + train_agent = agent_lib.TandemDqn( + preprocessor=preprocessor_builder(), + sample_network_input=sample_network_input, + network=network, + optimizer=optimizer, + loss=loss, + transition_accumulator=replay_lib.TransitionAccumulatorWithMCReturn(), + replay=replay, + batch_size=FLAGS.batch_size, + exploration_epsilon=exploration_epsilon_schedule, + min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction, + learn_period=FLAGS.learn_period, + target_network_update_period=FLAGS.target_network_update_period, + tied_layers=tied_layers, + rng_key=train_rng_key, + ) + eval_agent_active = parts.EpsilonGreedyActor( + preprocessor=preprocessor_builder(), + network=network.active, + exploration_epsilon=FLAGS.eval_exploration_epsilon, + rng_key=eval_rng_key) + eval_agent_passive = parts.EpsilonGreedyActor( + preprocessor=preprocessor_builder(), + network=network.passive, + exploration_epsilon=FLAGS.eval_exploration_epsilon, + rng_key=eval_rng_key) + + # Set up checkpointing. + checkpoint = parts.NullCheckpoint() + + state = checkpoint.state + state.iteration = 0 + state.train_agent = train_agent + state.eval_agent_active = eval_agent_active + state.eval_agent_passive = eval_agent_passive + state.random_state = random_state + state.writer = writer + if checkpoint.can_be_restored(): + checkpoint.restore() + + # Run single iteration of training or evaluation. + def run_iteration(agent, env, num_frames): + seq = parts.run_loop(agent, env, FLAGS.max_frames_per_episode) + seq_truncated = itertools.islice(seq, num_frames) + trackers = parts.make_default_trackers(agent) + return parts.generate_statistics(trackers, seq_truncated) + + def eval_log_output(eval_stats, suffix): + human_normalized_score = atari_data.get_human_normalized_score( + FLAGS.environment_name, eval_stats['episode_return']) + capped_human_normalized_score = np.amin([1., human_normalized_score]) + return [ + ('eval_episode_return_' + suffix, + eval_stats['episode_return'], '% 2.2f'), + ('eval_num_episodes_' + suffix, + eval_stats['num_episodes'], '%3d'), + ('eval_frame_rate_' + suffix, + eval_stats['step_rate'], '%4.0f'), + ('normalized_return_' + suffix, + human_normalized_score, '%.3f'), + ('capped_normalized_return_' + suffix, + capped_human_normalized_score, '%.3f'), + ('human_gap_' + suffix, + 1. - capped_human_normalized_score, '%.3f'), + ] + + while state.iteration <= FLAGS.num_iterations: + # New environment for each iteration to allow for determinism if preempted. + env = environment_builder() + + # Set agent to train active and passive nets on each learning step. + train_agent.set_training_mode('active_passive') + + logging.info('Training iteration %d.', state.iteration) + num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames + train_stats = run_iteration(train_agent, env, num_train_frames) + + logging.info('Evaluation iteration %d - active agent.', state.iteration) + eval_agent_active.network_params = train_agent.online_params.active + eval_stats_active = run_iteration(eval_agent_active, env, + FLAGS.num_eval_frames) + + logging.info('Evaluation iteration %d - passive agent.', state.iteration) + eval_agent_passive.network_params = train_agent.online_params.passive + eval_stats_passive = run_iteration(eval_agent_passive, env, + FLAGS.num_eval_frames) + + # Logging and checkpointing. + agent_logs = [ + 'loss_active', + 'loss_passive', + 'frac_diff_argmax', + 'mc_error_active', + 'mc_error_passive', + 'mc_error_abs_active', + 'mc_error_abs_passive', + ] + log_output = ( + eval_log_output(eval_stats_active, 'active') + + eval_log_output(eval_stats_passive, 'passive') + + [('iteration', state.iteration, '%3d'), + ('frame', state.iteration * FLAGS.num_train_frames, '%5d'), + ('train_episode_return', train_stats['episode_return'], '% 2.2f'), + ('train_num_episodes', train_stats['num_episodes'], '%3d'), + ('train_frame_rate', train_stats['step_rate'], '%4.0f'), + ] + + [(k, train_stats[k], '% 2.2f') for k in agent_logs] + ) + log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output) + logging.info(log_output_str) + writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output)) + state.iteration += 1 + checkpoint.save() + + writer.close() + + +if __name__ == '__main__': + config.update('jax_platform_name', 'gpu') # Default to GPU. + config.update('jax_numpy_rank_promotion', 'raise') + config.config_with_absl() + app.run(main)