Added tandem_dqn to deepmind-research repo

PiperOrigin-RevId: 405653277
This commit is contained in:
Georg Ostrovski
2021-10-26 16:30:00 +01:00
committed by Saran Tunyasuvunakool
parent f6da52cb38
commit 5f6b11299e
14 changed files with 2983 additions and 0 deletions
+1
View File
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [The Difficulty of Passive Learning in Deep Reinforcement Learning](tandem_dqn), NeurIPS 2021
* [Skilful precipitation nowcasting using deep generative models of radar](nowcasting), Nature 2021
* [Compute-Aided Design as Language](cadl)
* [Encoders and ensembles for continual learning](continual_learning)
+78
View File
@@ -0,0 +1,78 @@
# Tandem DQN
This repository provides an implementation of the Tandem DQN agent and
main experiments presented in the paper
"The Difficulty of Passive Learning in Deep Reinforcement Learning"
(Georg Ostrovski, Pablo Samuel Castro and Will Dabney, 2021).
The code is a modified fork of the [DoubleDQN](https://arxiv.org/abs/1509.06461)
agent in the [DQN Zoo](https://github.com/deepmind/dqn_zoo) agent collection.
## Quick start
This code can be run with a regular CPU setup, but will only be reasonably fast
(with the given configuration) if run with a GPU accelerator, in which case
you will need an NVIDIA GPU with recent CUDA drivers.
For installation, run
```shell
git clone git@github.com:deepmind/deepmind-research.git
virtualenv --python=python3.6 "tandem"
source tandem/bin/activate
cd deepmind-research
pip install -r tandem_dqn/requirements.txt
```
The code has only been tested with Python 3.6.14.
## Running the experiments presented in the paper
To execute the vanilla Tandem DQN experiment on the Pong environment, run:
```bash
python -m "tandem_dqn.run_tandem" --environment_name=pong --seed=42
```
A number of flags can be specified to customize execution and to run several of
the experimental variations presented in the paper:
* `--use_sticky_actions`
(values: `True`, `False`; default: `False`):
Use "sticky actions" variant
([Machado et al 2017](https://arxiv.org/abs/1709.06009)) of the Atari environment.
* `--network_active`, `--network_passive`
(values: `double_q`, `qr`; default: `double_q`):
Whether to use the DoubleDQN or QR-DQN
([Dabney et al 2017](https://arxiv.org/abs/1710.10044)) network architecture
for active and passive agents (can be set independently). Note this value
needs to be compatible with the chosen loss (see below).
* `--loss_active`, `--loss_passive`
(values: `double_q`, `double_q_v`, `double_q_p`, `double_q_pv`, `qr`,
`q_regression`; default: `double_q`):
Which loss to use for active and passive agent training. The losses are:
* `double_q`: regular Double-Q-Learning loss.
* `double_q_v`: Double-Q-Learning with bootstrap target _values_ provided
by the respective _other_ agent's target network.
* `double_q_p`: Double-Q-Learning with bootstrap target _policy_ (argmax)
provided by the respective _other_ agent's online network.
* `double_q_pv`: Double-Q-Learning with both bootstrap target values and
policy provided by the respective _other_ agent's target and online networks.
* `qr`: Quantile Regression Q-Learning loss
([Dabney et al 2017](https://arxiv.org/abs/1710.10044)).
* `q_regression`: Supervised regression loss towards the respective other
agent's online network's output values.
* `--optimizer_active`, `--optimizer_passive`
(values: `adam`, `rmsprop`; default: `rmsprop` for both):
Which optimization algorithm to use for the active and passive network training
(can be set independently).
* `--exploration_epsilon_end_value`
(values in `[0, 1]`; default: `0.01`):
Value of epsilon parameter in active agent's epsilon-greedy policy after an
initial decay phase.
+378
View File
@@ -0,0 +1,378 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tandem DQN agent class."""
import typing
from typing import Any, Callable, Mapping, Set, Text
from absl import logging
import dm_env
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
import rlax
from tandem_dqn import losses
from tandem_dqn import parts
from tandem_dqn import processors
from tandem_dqn import replay as replay_lib
class TandemTuple(typing.NamedTuple):
  """Pair of values, one for the active and one for the passive agent."""
  # Value belonging to the active agent.
  active: Any
  # Value belonging to the passive agent.
  passive: Any
def tandem_map(fn: Callable[..., Any], *args):
  """Applies `fn` separately to the active and passive parts of `args`.

  Args:
    fn: function called once with the `active` attribute of every argument
      and once with the `passive` attribute of every argument.
    *args: objects exposing `active` and `passive` attributes.

  Returns:
    A `TandemTuple` holding the two results of `fn`.
  """
  active_inputs = [arg.active for arg in args]
  active_result = fn(*active_inputs)
  passive_inputs = [arg.passive for arg in args]
  passive_result = fn(*passive_inputs)
  return TandemTuple(active=active_result, passive=passive_result)
def replace_module_params(source, target, modules):
  """Overwrites params of the selected modules in `target` by `source` values.

  Args:
    source: params structure providing the replacement values.
    target: params structure that the selected values are merged into.
    modules: container against which each module identifier of `source` is
      tested for membership.

  Returns:
    The merge of `target` with the selected part of `source`.
  """
  def is_selected(module, name, value):
    del name, value  # Selection depends on the module identifier only.
    return module in modules

  selected, _ = hk.data_structures.partition(is_selected, source)
  return hk.data_structures.merge(target, selected)
class TandemDqn(parts.Agent):
  """Tandem DQN agent.

  Holds an active and a passive network. Actions are always selected with the
  online parameters of the active network; which of the two networks receives
  learning updates is chosen with `set_training_mode`, which has to be called
  before the first learning step.
  """

  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: jnp.ndarray,
      network: TandemTuple,
      optimizer: TandemTuple,
      loss: TandemTuple,
      transition_accumulator: Any,
      replay: replay_lib.TransitionReplay,
      batch_size: int,
      exploration_epsilon: Callable[[int], float],
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      tied_layers: Set[str],
      rng_key: parts.PRNGKey,
  ):
    """Initializes the agent and builds its jitted update functions.

    Args:
      preprocessor: applied to every incoming timestep; a `None` result makes
        the agent repeat its previous action.
      sample_network_input: example network input (without batch axis) used
        to initialize the parameters of both networks.
      network: active and passive networks.
      optimizer: active and passive optimizers.
      loss: active and passive loss functions.
      transition_accumulator: object whose `step(timestep, action)` yields the
        transitions to be added to replay.
      replay: replay that transitions are added to and sampled from.
      batch_size: number of transitions sampled per learning step.
      exploration_epsilon: maps the current frame index to the epsilon of the
        epsilon-greedy behavior policy.
      min_replay_capacity_fraction: fraction of `replay.capacity` that has to
        be filled before learning starts.
      learn_period: a learning step is run on frames whose index is a
        multiple of this value.
      target_network_update_period: target params are set to the online
        params on frames whose index is a multiple of this value.
      tied_layers: modules whose passive params are overwritten by the active
        ones after every learning step.
      rng_key: random key from which all agent randomness is derived.
    """
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key_active, network_rng_key_passive = (
        jax.random.split(rng_key, 3))
    active_params = network.active.init(
        network_rng_key_active, sample_network_input[None, ...])
    passive_params = network.passive.init(
        network_rng_key_passive, sample_network_input[None, ...])
    self._online_params = TandemTuple(
        active=active_params, passive=passive_params)
    self._target_params = self._online_params
    self._opt_state = tandem_map(
        lambda optim, params: optim.init(params),
        optimizer, self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.

    # Stats.
    stats = [
        'loss_active',
        'loss_passive',
        'frac_diff_argmax',
        'mc_error_active',
        'mc_error_passive',
        'mc_error_abs_active',
        'mc_error_abs_passive',
    ]
    self._statistics = {k: np.nan for k in stats}

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.
    def network_outputs(rng_key, online_params, target_params, transitions):
      """Compute all potentially needed outputs of active and passive net."""
      _, *apply_keys = jax.random.split(rng_key, 4)
      outputs_tm1 = tandem_map(
          lambda net, param: net.apply(param, apply_keys[0], transitions.s_tm1),
          network, online_params)
      outputs_t = tandem_map(
          lambda net, param: net.apply(param, apply_keys[1], transitions.s_t),
          network, online_params)
      outputs_target_t = tandem_map(
          lambda net, param: net.apply(param, apply_keys[2], transitions.s_t),
          network, target_params)
      return outputs_tm1, outputs_t, outputs_target_t

    # Helper functions to define active and passive losses.
    # Active and passive losses are allowed to depend on all active and passive
    # outputs, but stop-gradient is used to prevent gradients from flowing
    # from active loss to passive network params and vice versa.
    def sg_active(x):
      return TandemTuple(
          active=jax.lax.stop_gradient(x.active), passive=x.passive)

    def sg_passive(x):
      return TandemTuple(
          active=x.active, passive=jax.lax.stop_gradient(x.passive))

    def compute_loss(online_params, target_params, transitions, rng_key):
      """Returns the summed active and passive loss plus a dict of logs."""
      rng_key, apply_key = jax.random.split(rng_key)
      outputs_tm1, outputs_t, outputs_target_t = network_outputs(
          apply_key, online_params, target_params, transitions)
      _, loss_key_active, loss_key_passive = jax.random.split(rng_key, 3)
      loss_active = loss.active(
          sg_passive(outputs_tm1), sg_passive(outputs_t), outputs_target_t,
          transitions, loss_key_active)
      loss_passive = loss.passive(
          sg_active(outputs_tm1), sg_active(outputs_t), outputs_target_t,
          transitions, loss_key_passive)

      # Logging stuff.
      a_tm1 = transitions.a_tm1
      mc_return_tm1 = transitions.mc_return_tm1
      q_values = TandemTuple(
          active=outputs_tm1.active.q_values,
          passive=outputs_tm1.passive.q_values)
      mc_error = jax.tree_map(
          lambda q: losses.batch_mc_learning(q, a_tm1, mc_return_tm1),
          q_values)
      mc_error_abs = jax.tree_map(jnp.abs, mc_error)
      q_argmax = jax.tree_map(lambda q: jnp.argmax(q, axis=-1), q_values)
      argmax_diff = jnp.not_equal(q_argmax.active, q_argmax.passive)

      batch_mean = lambda x: jnp.mean(x, axis=0)
      logs = {
          'loss_active': loss_active,
          'loss_passive': loss_passive
      }
      logs.update(jax.tree_map(batch_mean, {
          'frac_diff_argmax': argmax_diff,
          'mc_error_active': mc_error.active,
          'mc_error_passive': mc_error.passive,
          'mc_error_abs_active': mc_error_abs.active,
          'mc_error_abs_passive': mc_error_abs.passive,
      }))
      return loss_active + loss_passive, logs

    def optim_update(optim, online_params, d_loss_d_params, opt_state):
      """Applies one optimizer step to the params of a single network."""
      updates, new_opt_state = optim.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return new_opt_state, new_online_params

    def compute_loss_grad(rng_key, online_params, target_params, transitions):
      """Returns new rng key, logs and loss gradient wrt. online params."""
      rng_key, grad_key = jax.random.split(rng_key)
      (_, logs), d_loss_d_params = jax.value_and_grad(
          compute_loss, has_aux=True)(
              online_params, target_params, transitions, grad_key)
      return rng_key, logs, d_loss_d_params

    def update_active(rng_key, opt_state, online_params, target_params,
                      transitions):
      """Applies learning update for active network only."""
      rng_key, logs, d_loss_d_params = compute_loss_grad(
          rng_key, online_params, target_params, transitions)
      new_opt_state_active, new_online_params_active = optim_update(
          optimizer.active, online_params.active, d_loss_d_params.active,
          opt_state.active)
      new_opt_state = opt_state._replace(
          active=new_opt_state_active)
      new_online_params = online_params._replace(
          active=new_online_params_active)
      return rng_key, new_opt_state, new_online_params, logs
    self._update_active = jax.jit(update_active)

    def update_passive(rng_key, opt_state, online_params, target_params,
                       transitions):
      """Applies learning update for passive network only."""
      rng_key, logs, d_loss_d_params = compute_loss_grad(
          rng_key, online_params, target_params, transitions)
      new_opt_state_passive, new_online_params_passive = optim_update(
          optimizer.passive, online_params.passive, d_loss_d_params.passive,
          opt_state.passive)
      new_opt_state = opt_state._replace(
          passive=new_opt_state_passive)
      new_online_params = online_params._replace(
          passive=new_online_params_passive)
      return rng_key, new_opt_state, new_online_params, logs
    self._update_passive = jax.jit(update_passive)

    def update_active_passive(rng_key, opt_state, online_params,
                              target_params, transitions):
      """Applies learning update for both active & passive networks."""
      rng_key, logs, d_loss_d_params = compute_loss_grad(
          rng_key, online_params, target_params, transitions)
      new_opt_state_active, new_online_params_active = optim_update(
          optimizer.active, online_params.active, d_loss_d_params.active,
          opt_state.active)
      new_opt_state_passive, new_online_params_passive = optim_update(
          optimizer.passive, online_params.passive, d_loss_d_params.passive,
          opt_state.passive)
      new_opt_state = TandemTuple(active=new_opt_state_active,
                                  passive=new_opt_state_passive)
      new_online_params = TandemTuple(active=new_online_params_active,
                                      passive=new_online_params_passive)
      return rng_key, new_opt_state, new_online_params, logs
    self._update_active_passive = jax.jit(update_active_passive)

    self._update = None  # set_training_mode needs to be called to set this.

    def sync_tied_layers(online_params):
      """Set tied layer params of passive to respective values of active."""
      new_online_params_passive = replace_module_params(
          source=online_params.active, target=online_params.passive,
          modules=tied_layers)
      return online_params._replace(passive=new_online_params_passive)
    self._sync_tied_layers = jax.jit(sync_tied_layers)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      q_t = network.active.apply(network_params, apply_key,
                                 s_t[None, ...]).q_values[0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      return rng_key, a_t
    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> parts.Action:
    """Selects action given timestep and potentially learns.

    Learning steps and target network updates only happen once the replay
    holds at least the configured minimum number of transitions.

    Args:
      timestep: environment timestep, passed through the preprocessor.

    Returns:
      The selected action, or the previous action if the preprocessor
      returned `None` for this timestep.
    """
    self._frame_t += 1

    timestep = self._preprocessor(timestep)

    if timestep is None:  # Repeat action.
      action = self._action
    else:
      action = self._action = self._act(timestep)

      for transition in self._transition_accumulator.step(timestep, action):
        self._replay.add(transition)

    if self._replay.size < self._min_replay_capacity:
      return action

    if self._frame_t % self._learn_period == 0:
      self._learn()

    if self._frame_t % self._target_network_update_period == 0:
      self._target_params = self._online_params

    return action

  def reset(self) -> None:
    """Resets the agent's episodic state such as frame stack and action repeat.

    This method should be called at the beginning of every episode.
    """
    self._transition_accumulator.reset()
    processors.reset(self._preprocessor)
    self._action = None

  def _act(self, timestep) -> parts.Action:
    """Selects action given timestep, according to epsilon-greedy policy."""
    s_t = timestep.observation
    # The behavior policy always uses the active network.
    network_params = self._online_params.active
    self._rng_key, a_t = self._select_action(
        self._rng_key, network_params, s_t, self.exploration_epsilon)
    return parts.Action(jax.device_get(a_t))

  def _learn(self) -> None:
    """Samples a batch of transitions from replay and learns from it.

    Raises:
      RuntimeError: if no training mode has been set yet.
    """
    if self._update is None:
      # Fail with an actionable message instead of "NoneType is not callable".
      raise RuntimeError(
          'set_training_mode must be called before the agent can learn.')
    logging.log_first_n(logging.INFO, 'Begin learning', 1)
    transitions = self._replay.sample(self._batch_size)
    self._rng_key, self._opt_state, self._online_params, logs = self._update(
        self._rng_key,
        self._opt_state,
        self._online_params,
        self._target_params,
        transitions,
    )
    self._online_params = self._sync_tied_layers(self._online_params)
    self._statistics.update(jax.device_get(logs))

  def set_training_mode(self, mode: str):
    """Sets training mode to one of 'active', 'passive', or 'active_passive'.

    Args:
      mode: which of the networks receive learning updates.

    Raises:
      ValueError: if `mode` is not one of the supported values.
    """
    if mode == 'active':
      self._update = self._update_active
    elif mode == 'passive':
      self._update = self._update_passive
    elif mode == 'active_passive':
      self._update = self._update_active_passive
    else:
      # An unrecognized mode used to be ignored silently, leaving the
      # previous (possibly unset) update function in place.
      raise ValueError('Unknown training mode "{}"'.format(mode))

  @property
  def online_params(self) -> TandemTuple:
    """Returns current parameters of Q-network."""
    return self._online_params

  @property
  def statistics(self) -> Mapping[Text, float]:
    """Returns current agent statistics as a dictionary."""
    # Check for DeviceArrays in values as this can be very slow.
    assert all(
        not isinstance(x, jnp.DeviceArray) for x in self._statistics.values())
    return self._statistics

  @property
  def exploration_epsilon(self) -> float:
    """Returns epsilon value currently used by (eps-greedy) behavior policy."""
    return self._exploration_epsilon(self._frame_t)

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves agent state as a dictionary (e.g. for serialization)."""
    state = {
        'rng_key': self._rng_key,
        'frame_t': self._frame_t,
        'opt_state_active': self._opt_state.active,
        'online_params_active': self._online_params.active,
        'target_params_active': self._target_params.active,
        'opt_state_passive': self._opt_state.passive,
        'online_params_passive': self._online_params.passive,
        'target_params_passive': self._target_params.passive,
        'replay': self._replay.get_state(),
    }
    return state

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets agent state from a (potentially de-serialized) dictionary."""
    self._rng_key = state['rng_key']
    self._frame_t = state['frame_t']
    self._opt_state = TandemTuple(
        active=jax.device_put(state['opt_state_active']),
        passive=jax.device_put(state['opt_state_passive']))
    self._online_params = TandemTuple(
        active=jax.device_put(state['online_params_active']),
        passive=jax.device_put(state['online_params_passive']))
    self._target_params = TandemTuple(
        active=jax.device_put(state['target_params_active']),
        passive=jax.device_put(state['target_params_passive']))
    self._replay.set_state(state['replay'])
+110
View File
@@ -0,0 +1,110 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to compute human-normalized Atari scores.
The data used in this module is human and random performance data on Atari-57.
It comprises evaluation scores (undiscounted returns), each averaged
over at least 3 episode runs, on each of the 57 Atari games. Each episode begins
with the environment already stepped with a uniform random number (between 1 and
30 inclusive) of noop actions.
The two agents are:
* 'random' (agent choosing its actions uniformly randomly on each step)
* 'human' (professional human game tester)
Scores are obtained by averaging returns over the episodes played by each agent,
with episode length capped to 108,000 frames (i.e. timeout after 30 minutes).
The term 'human-normalized' here means a linear per-game transformation of
a game score in such a way that 0 corresponds to random performance and 1
corresponds to human performance.
"""
import math
# Game: score-tuple dictionary. Each score tuple contains
# 0: score random (float) and 1: score human (float).
_ATARI_DATA = {
    'alien': (227.8, 7127.7),
    'amidar': (5.8, 1719.5),
    'assault': (222.4, 742.0),
    'asterix': (210.0, 8503.3),
    'asteroids': (719.1, 47388.7),
    'atlantis': (12850.0, 29028.1),
    'bank_heist': (14.2, 753.1),
    'battle_zone': (2360.0, 37187.5),
    'beam_rider': (363.9, 16926.5),
    'berzerk': (123.7, 2630.4),
    'bowling': (23.1, 160.7),
    'boxing': (0.1, 12.1),
    'breakout': (1.7, 30.5),
    'centipede': (2090.9, 12017.0),
    'chopper_command': (811.0, 7387.8),
    'crazy_climber': (10780.5, 35829.4),
    'defender': (2874.5, 18688.9),
    'demon_attack': (152.1, 1971.0),
    'double_dunk': (-18.6, -16.4),
    'enduro': (0.0, 860.5),
    'fishing_derby': (-91.7, -38.7),
    'freeway': (0.0, 29.6),
    'frostbite': (65.2, 4334.7),
    'gopher': (257.6, 2412.5),
    'gravitar': (173.0, 3351.4),
    'hero': (1027.0, 30826.4),
    'ice_hockey': (-11.2, 0.9),
    'jamesbond': (29.0, 302.8),
    'kangaroo': (52.0, 3035.0),
    'krull': (1598.0, 2665.5),
    'kung_fu_master': (258.5, 22736.3),
    'montezuma_revenge': (0.0, 4753.3),
    'ms_pacman': (307.3, 6951.6),
    'name_this_game': (2292.3, 8049.0),
    'phoenix': (761.4, 7242.6),
    'pitfall': (-229.4, 6463.7),
    'pong': (-20.7, 14.6),
    'private_eye': (24.9, 69571.3),
    'qbert': (163.9, 13455.0),
    'riverraid': (1338.5, 17118.0),
    'road_runner': (11.5, 7845.0),
    'robotank': (2.2, 11.9),
    'seaquest': (68.4, 42054.7),
    'skiing': (-17098.1, -4336.9),
    'solaris': (1236.3, 12326.7),
    'space_invaders': (148.0, 1668.7),
    'star_gunner': (664.0, 10250.0),
    'surround': (-10.0, 6.5),
    'tennis': (-23.8, -8.3),
    'time_pilot': (3568.0, 5229.2),
    'tutankham': (11.4, 167.6),
    'up_n_down': (533.4, 11693.2),
    'venture': (0.0, 1187.5),
    # Note the random agent score on Video Pinball is sometimes greater than the
    # human score under other evaluation methods.
    'video_pinball': (16256.9, 17667.9),
    'wizard_of_wor': (563.5, 4756.5),
    'yars_revenge': (3092.9, 54576.9),
    'zaxxon': (32.5, 9173.3),
}
# Positions of the random and the human score within each score tuple.
_RANDOM_COL = 0
_HUMAN_COL = 1
# Names of all games in the score table, in sorted order.
ATARI_GAMES = tuple(sorted(_ATARI_DATA.keys()))
def get_human_normalized_score(game: str, raw_score: float) -> float:
  """Maps a raw game score onto the human-normalized scale.

  The result is 0 at the random agent's score and 1 at the human score of
  `game`. Games missing from the score table yield NaN.

  Args:
    game: name of the game, as used for the keys of the score table.
    raw_score: score to be normalized.

  Returns:
    The human-normalized score.
  """
  unknown_game_scores = (math.nan, math.nan)
  scores = _ATARI_DATA.get(game, unknown_game_scores)
  random_score = scores[_RANDOM_COL]
  human_score = scores[_HUMAN_COL]
  score_above_random = raw_score - random_score
  human_above_random = human_score - random_score
  return score_above_random / human_above_random
+217
View File
@@ -0,0 +1,217 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""dm_env environment wrapper around Gym Atari configured to be like Xitari.
Gym Atari is built on the Arcade Learning Environment (ALE), whereas Xitari is
an old fork of the ALE.
"""
# pylint: disable=g-bad-import-order
from typing import Optional, Tuple
import atari_py # pylint: disable=unused-import for gym to load Atari games.
import dm_env
from dm_env import specs
import gym
import numpy as np
from tandem_dqn import atari_data
# Suffix ending the id of every Gym environment built by `_game_id`.
_GYM_ID_SUFFIX = '-xitari-v1'
# Marker inserted into the id of the sticky-actions environment variants.
_SA_SUFFIX = '-sa'
def _game_id(game, sticky_actions):
  """Builds the Gym environment id of `game`, with or without sticky actions."""
  if sticky_actions:
    variant = _SA_SUFFIX
  else:
    variant = ''
  return game + variant + _GYM_ID_SUFFIX
def _register_atari_environments():
  """Registers Xitari-like variants of the Atari games with Gym.

  Compared to PongNoFrameSkip-v4 and friends, `max_episode_steps` is left
  unset and only the usual 57 Atari games are registered.

  Every game is registered twice: once without sticky actions and once, with
  an '-sa' suffix in its id, with sticky actions.
  """
  # Pairs of (sticky_actions, repeat_action_probability), plain variant first.
  variants = ((False, 0.0), (True, 0.25))
  for sticky_actions, repeat_action_probability in variants:
    for game in atari_data.ATARI_GAMES:
      # Explicitly set all known arguments.
      env_kwargs = {
          'game': game,
          'mode': None,  # Not necessarily the same as 0.
          'difficulty': None,  # Not necessarily the same as 0.
          'obs_type': 'image',
          'frameskip': 1,  # Get every frame.
          'repeat_action_probability': repeat_action_probability,
          'full_action_space': False,
      }
      gym.envs.registration.register(
          id=_game_id(game, sticky_actions),
          entry_point='gym.envs.atari:AtariEnv',
          kwargs=env_kwargs,
          max_episode_steps=None,  # No time limit, handled in run loop.
          nondeterministic=False,  # Xitari is deterministic.
      )
# Module-level call: the environments get registered when this module is
# imported.
_register_atari_environments()
class GymAtari(dm_env.Environment):
  """Adapter exposing a Gym Atari game through the `dm_env` interface.

  Observations are `(frame, lives)` pairs, where `lives` is read from the
  underlying ALE instance.
  """

  def __init__(self, game, sticky_actions, seed):
    """Creates and seeds the Gym environment for `game`."""
    env_id = _game_id(game, sticky_actions)
    self._gym_env = gym.make(env_id)
    self._gym_env.seed(seed)
    # True whenever the next call to `step()` has to begin a new episode.
    self._start_of_episode = True

  def reset(self) -> dm_env.TimeStep:
    """Starts a new episode and returns its first timestep."""
    frame = self._gym_env.reset()
    lives = np.int32(self._gym_env.ale.lives())
    first_timestep = dm_env.restart((frame, lives))
    self._start_of_episode = False
    return first_timestep

  def step(self, action: np.int32) -> dm_env.TimeStep:
    """Advances the environment by one step and returns the new timestep.

    When a new episode has to be started (first call, or previous timestep was
    LAST), the Gym environment is reset instead of stepped and `action` is not
    applied. Gym would allow stepping through episode boundaries (similar to
    dm_env) but emits a warning when doing so.
    """
    if self._start_of_episode:
      frame = self._gym_env.reset()
      episode_over = False
      step_type, reward, discount = dm_env.StepType.FIRST, None, None
    else:
      frame, reward, episode_over, info = self._gym_env.step(action)
      if not episode_over:
        step_type, discount = dm_env.StepType.MID, 1.
      else:
        assert 'TimeLimit.truncated' not in info, 'Should never truncate.'
        step_type, discount = dm_env.StepType.LAST, 0.

    lives = np.int32(self._gym_env.ale.lives())
    self._start_of_episode = episode_over
    return dm_env.TimeStep(
        step_type=step_type,
        observation=(frame, lives),
        reward=reward,
        discount=discount,
    )

  def observation_spec(self) -> Tuple[specs.Array, specs.Array]:
    """Returns the specs of the `(frame, lives)` observation pair."""
    frame_space = self._gym_env.observation_space
    frame_spec = specs.Array(
        shape=frame_space.shape, dtype=frame_space.dtype, name='rgb')
    lives_spec = specs.Array(shape=(), dtype=np.int32, name='lives')
    return (frame_spec, lives_spec)

  def action_spec(self) -> specs.DiscreteArray:
    """Returns the spec of the discrete action set of the game."""
    num_actions = self._gym_env.action_space.n
    return specs.DiscreteArray(
        num_values=num_actions, dtype=np.int32, name='action')

  def close(self):
    """Closes the underlying Gym environment."""
    self._gym_env.close()
class RandomNoopsEnvironmentWrapper(dm_env.Environment):
  """Begins every episode with a random number of noop actions."""

  def __init__(self,
               environment: dm_env.Environment,
               max_noop_steps: int,
               min_noop_steps: int = 0,
               noop_action: int = 0,
               seed: Optional[int] = None):
    """Initializes the random noops environment wrapper.

    Args:
      environment: environment to wrap.
      max_noop_steps: largest number of noops applied per episode (inclusive).
      min_noop_steps: smallest number of noops applied per episode.
      noop_action: action passed to the wrapped environment as the noop.
      seed: seed of the random state used to draw the number of noops.

    Raises:
      ValueError: if `max_noop_steps` is smaller than `min_noop_steps`.
    """
    self._environment = environment
    if min_noop_steps > max_noop_steps:
      raise ValueError('max_noop_steps must be greater or equal min_noop_steps')
    self._noop_action = noop_action
    self._min_noop_steps = min_noop_steps
    self._max_noop_steps = max_noop_steps
    self._rng = np.random.RandomState(seed)

  def reset(self):
    """Begins a new episode.

    The wrapped environment is reset and then stepped with a random number of
    noop actions. The observation reached this way is returned as the first
    timestep of the episode; all intermediate timesteps of the wrapped
    environment, with their rewards and discounts, are dropped.

    Returns:
      First episode timestep, holding the observation after the noop actions.

    Raises:
      RuntimeError: if the episode ends while the noop actions are applied.
    """
    first_timestep = self._environment.reset()
    return self._apply_random_noops(initial_timestep=first_timestep)

  def step(self, action):
    """Steps the wrapped environment with `action`.

    Args:
      action: action to pass to environment conforming to action spec.

    Returns:
      The timestep of the wrapped environment, unless that timestep begins a
      new episode; then random noops are applied as in `reset()` and the
      resulting first timestep is returned.
    """
    timestep = self._environment.step(action)
    if not timestep.first():
      return timestep
    return self._apply_random_noops(initial_timestep=timestep)

  def _apply_random_noops(self, initial_timestep):
    """Applies a random number of noops, returns the resulting FIRST timestep."""
    assert initial_timestep.first()
    noop_count = self._rng.randint(self._min_noop_steps,
                                   self._max_noop_steps + 1)
    latest = initial_timestep
    for _ in range(noop_count):
      latest = self._environment.step(self._noop_action)
      if latest.last():
        raise RuntimeError('Episode ended while applying %s noop actions.' %
                           noop_count)

    # Only the observation is kept, so that rewards & discounts collected
    # during the noops are discarded and a FIRST timestep is returned.
    return dm_env.restart(latest.observation)

  ## Everything below merely forwards to the wrapped environment.

  def observation_spec(self):
    """Returns the observation spec of the wrapped environment."""
    return self._environment.observation_spec()

  def action_spec(self):
    """Returns the action spec of the wrapped environment."""
    return self._environment.action_spec()

  def reward_spec(self):
    """Returns the reward spec of the wrapped environment."""
    return self._environment.reward_spec()

  def discount_spec(self):
    """Returns the discount spec of the wrapped environment."""
    return self._environment.discount_spec()

  def close(self):
    """Closes the wrapped environment."""
    return self._environment.close()
+190
View File
@@ -0,0 +1,190 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses for TandemDQN."""
from typing import Any, Callable
import chex
import jax
import jax.numpy as jnp
import rlax
from tandem_dqn import networks
# Batch variants of double_q_learning and SARSA: every argument is mapped
# over its leading axis.
batch_double_q_learning = jax.vmap(rlax.double_q_learning)
batch_sarsa_learning = jax.vmap(rlax.sarsa)
# Batch variant of quantile_q_learning with fixed tau input across batch.
# The `None` entries of `in_axes` mark the arguments that are not mapped:
# the second one (quantiles) and the last one (huber parameter).
batch_quantile_q_learning = jax.vmap(
    rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
def _mc_learning(
    q_tm1: chex.Array,
    a_tm1: chex.Numeric,
    mc_return_tm1: chex.Array,
) -> chex.Numeric:
  """Computes the error between MC return and value of the taken action.

  Args:
    q_tm1: rank-1 float array of action values.
    a_tm1: scalar integer index of the action that was taken.
    mc_return_tm1: Monte-Carlo return to compare the action value against.

  Returns:
    `mc_return_tm1` minus the entry of `q_tm1` selected by `a_tm1`.
  """
  chex.assert_rank([q_tm1, a_tm1], [1, 0])
  chex.assert_type([q_tm1, a_tm1], [float, int])
  q_taken = q_tm1[a_tm1]
  return mc_return_tm1 - q_taken
# Batch variant of MC learning: all three arguments are mapped over their
# leading axis.
batch_mc_learning = jax.vmap(_mc_learning)
def _qr_loss(q_tm1, q_t, q_target_t, transitions, rng_key):
  """Computes the batch mean of the quantile regression Q-learning loss.

  Args:
    q_tm1: network outputs at the previous state; its `q_dist` is trained.
    q_t: unused.
    q_target_t: target network outputs at the current state; its `q_dist` is
      used both to select and to evaluate the bootstrap action.
    transitions: batch providing `a_tm1`, `r_t` and `discount_t`.
    rng_key: unused.

  Returns:
    Scalar loss.
  """
  del q_t, rng_key  # Unused.

  huber_threshold = 1.
  taus = networks.make_quantiles()
  # No double Q-learning here: the target distribution serves as selector too.
  dist_selector_t = q_target_t.q_dist
  dist_target_t = q_target_t.q_dist
  per_example_loss = batch_quantile_q_learning(
      q_tm1.q_dist,
      taus,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      dist_selector_t,
      dist_target_t,
      huber_threshold,
  )
  return jnp.mean(per_example_loss)
def _sarsa_loss(q_tm1, q_t, transitions, rng_key):
  """Computes the batch mean of the SARSA loss.

  Args:
    q_tm1: network outputs at the previous state; its `q_values` are trained.
    q_t: network outputs at the current state providing bootstrap `q_values`.
    transitions: batch providing `a_tm1`, `r_t`, `discount_t` and `a_t`.
    rng_key: unused.

  Returns:
    Scalar loss.
  """
  del rng_key  # Unused.

  clip_bound = 1. / 32
  raw_td_errors = batch_sarsa_learning(
      q_tm1.q_values,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_t.q_values,
      transitions.a_t
  )
  clipped_td_errors = rlax.clip_gradient(
      raw_td_errors, -clip_bound, clip_bound)
  return jnp.mean(rlax.l2_loss(clipped_td_errors))
def _mc_loss(q_tm1, transitions, rng_key):
  """Computes the batch mean L2 loss of action values vs. MC returns.

  Args:
    q_tm1: network outputs at the previous state; its `q_values` are trained.
    transitions: batch providing `a_tm1` and `mc_return_tm1`.
    rng_key: unused.

  Returns:
    Scalar loss.
  """
  del rng_key  # Unused.

  mc_errors = batch_mc_learning(
      q_tm1.q_values, transitions.a_tm1, transitions.mc_return_tm1)
  per_example_loss = rlax.l2_loss(mc_errors)
  return jnp.mean(per_example_loss)
def _double_q_loss(q_tm1, q_t, q_target_t, transitions, rng_key):
  """Computes the batch mean of the Double Q-learning loss.

  Args:
    q_tm1: network outputs at the previous state; its `q_values` are trained.
    q_t: outputs at the current state whose `q_values` are passed as the
      action selector.
    q_target_t: outputs at the current state whose `q_values` are passed as
      the bootstrap values.
    transitions: batch providing `a_tm1`, `r_t` and `discount_t`.
    rng_key: unused.

  Returns:
    Scalar loss.
  """
  del rng_key  # Unused.

  clip_bound = 1. / 32
  raw_td_errors = batch_double_q_learning(
      q_tm1.q_values,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_target_t.q_values,
      q_t.q_values,
  )
  clipped_td_errors = rlax.clip_gradient(
      raw_td_errors, -clip_bound, clip_bound)
  return jnp.mean(rlax.l2_loss(clipped_td_errors))
def _q_regression_loss(q_tm1, q_tm1_target):
  """Computes the mean L2 loss between all action values and fixed targets.

  Args:
    q_tm1: network outputs whose `q_values` are trained.
    q_tm1_target: outputs whose `q_values` serve as regression targets; no
      gradient flows into them.

  Returns:
    Scalar loss.
  """
  fixed_targets = jax.lax.stop_gradient(q_tm1_target.q_values)
  deviations = q_tm1.q_values - fixed_targets
  return jnp.mean(rlax.l2_loss(deviations))
def make_loss_fn(loss_type: str, active: bool) -> Callable[..., Any]:
  """Builds the loss function of the given type for one of the two agents.

  Every returned function has the signature
  `(q_tm1, q_t, q_target_t, transitions, rng_key)`, where the first three
  arguments carry an `active` and a `passive` attribute each.

  Args:
    loss_type: one of 'double_q', 'sarsa', 'mc_return', 'double_q_v',
      'double_q_p', 'double_q_pv', 'q_regression' and 'qr'.
    active: whether the loss trains the active (True) or passive (False) agent.

  Returns:
    The loss function.

  Raises:
    ValueError: if `loss_type` is not a known loss.
  """
  # Attribute of the agent being trained, and attribute of the other agent.
  own, other = ('active', 'passive') if active else ('passive', 'active')
  primary = lambda x: getattr(x, own)
  secondary = lambda x: getattr(x, other)

  def double_q_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """Regular DoubleQ loss using own networks."""
    return _double_q_loss(primary(q_tm1), primary(q_t), primary(q_target_t),
                          transitions, rng_key)

  def sarsa_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """SARSA loss using own networks."""
    del q_t  # Unused.
    return _sarsa_loss(primary(q_tm1), primary(q_target_t), transitions,
                       rng_key)

  def mc_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """MonteCarlo loss."""
    del q_t, q_target_t
    return _mc_loss(primary(q_tm1), transitions, rng_key)

  def double_q_loss_v_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """DoubleQ loss using other network's (target) value function."""
    return _double_q_loss(primary(q_tm1), primary(q_t), secondary(q_target_t),
                          transitions, rng_key)

  def double_q_loss_p_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """DoubleQ loss using other network's (online) argmax policy."""
    return _double_q_loss(primary(q_tm1), secondary(q_t), primary(q_target_t),
                          transitions, rng_key)

  def double_q_loss_pv_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """DoubleQ loss using other network's argmax policy & target value fn."""
    return _double_q_loss(primary(q_tm1), secondary(q_t), secondary(q_target_t),
                          transitions, rng_key)

  def q_regression_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """Pure regression of q_tm1(self) towards q_tm1(other)."""
    del q_t, q_target_t, transitions, rng_key  # Unused.
    return _q_regression_loss(primary(q_tm1), secondary(q_tm1))

  def qr_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """QR-Q loss using own networks."""
    return _qr_loss(primary(q_tm1), primary(q_t), primary(q_target_t),
                    transitions, rng_key)

  loss_fns = {
      'double_q': double_q_loss_fn,
      'sarsa': sarsa_loss_fn,
      'mc_return': mc_loss_fn,
      'double_q_v': double_q_loss_v_fn,
      'double_q_p': double_q_loss_p_fn,
      'double_q_pv': double_q_loss_pv_fn,
      'q_regression': q_regression_loss_fn,
      'qr': qr_loss_fn,
  }
  if loss_type not in loss_fns:
    raise ValueError('Unknown loss "{}"'.format(loss_type))
  return loss_fns[loss_type]
+95
View File
@@ -0,0 +1,95 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for Tandem losses."""
from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from jax.config import config
import numpy as np
from tandem_dqn import agent
from tandem_dqn import losses
from tandem_dqn import networks
from tandem_dqn import replay
def make_tandem_qvals():
  """Returns constant tandem Q-values: all 3s (active), all 2s (passive)."""
  shape = (3, 5)
  active_q = networks.QNetworkOutputs(np.full(shape, 3., np.float32))
  passive_q = networks.QNetworkOutputs(np.full(shape, 2., np.float32))
  return agent.TandemTuple(active=active_q, passive=passive_q)
def make_transition():
  """Returns a constant batch of three transitions for the loss tests."""
  batch_size = 3
  return replay.Transition(
      s_tm1=np.zeros(batch_size),
      a_tm1=np.ones(batch_size, np.int32),
      r_t=np.full(batch_size, 5.),
      discount_t=np.full(batch_size, 0.9),
      s_t=np.zeros(batch_size))
class DoubleQLossesTest(parameterized.TestCase):
  """Checks which network outputs receive gradients from the tandem losses."""

  # Loss types exercised by both gradient tests.
  _LOSS_TYPES = (
      ('double_q',), ('double_q_v',), ('double_q_p',), ('double_q_pv',),
      ('q_regression',),
  )

  def setUp(self):
    super().setUp()
    self.qs = make_tandem_qvals()
    self.transition = make_transition()
    self.rng_key = jax.random.PRNGKey(42)

  def _assert_all_zero(self, gradient):
    """Asserts that every entry of `gradient` is exactly zero."""
    self.assertTrue(np.all(gradient == 0.))

  @chex.all_variants()
  @parameterized.parameters(*_LOSS_TYPES)
  def test_active_loss_gradients(self, loss_type):
    active_loss = losses.make_loss_fn(loss_type, active=True)

    def fn(q_tm1, q_t, q_t_target, transition, rng_key):
      return active_loss(q_tm1, q_t, q_t_target, transition, rng_key)

    grad_fn = self.variant(jax.grad(fn, argnums=(0, 1, 2)))
    grad_tm1, grad_t, grad_t_target = grad_fn(
        self.qs, self.qs, self.qs, self.transition, self.rng_key)

    # Assert that only active net gets nonzero gradients.
    self.assertGreater(np.sum(np.abs(grad_tm1.active.q_values)), 0.)
    self._assert_all_zero(grad_t.active.q_values)
    self._assert_all_zero(grad_t_target.active.q_values)
    self._assert_all_zero(grad_t.passive.q_values)
    self._assert_all_zero(grad_tm1.passive.q_values)
    self._assert_all_zero(grad_t_target.passive.q_values)

  @chex.all_variants()
  @parameterized.parameters(*_LOSS_TYPES)
  def test_passive_loss_gradients(self, loss_type):
    passive_loss = losses.make_loss_fn(loss_type, active=False)

    def fn(q_tm1, q_t, q_t_target, transition, rng_key):
      return passive_loss(q_tm1, q_t, q_t_target, transition, rng_key)

    grad_fn = self.variant(jax.grad(fn, argnums=(0, 1, 2)))
    grad_tm1, grad_t, grad_t_target = grad_fn(
        self.qs, self.qs, self.qs, self.transition, self.rng_key)

    # Assert that only passive net gets nonzero gradients.
    self.assertGreater(np.sum(np.abs(grad_tm1.passive.q_values)), 0.)
    self._assert_all_zero(grad_t.passive.q_values)
    self._assert_all_zero(grad_t_target.passive.q_values)
    self._assert_all_zero(grad_t.active.q_values)
    self._assert_all_zero(grad_tm1.active.q_values)
    self._assert_all_zero(grad_t_target.active.q_values)
if __name__ == '__main__':
  # Configure JAX to raise on implicit NumPy rank promotion before running
  # the tests.
  config.update('jax_numpy_rank_promotion', 'raise')
  absltest.main()
+216
View File
@@ -0,0 +1,216 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""DQN agent network components and implementation."""
import typing
from typing import Any, Callable, Tuple, Union
import chex
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
# Type aliases used in the annotations of this module.
Network = hk.Transformed  # What `make_network` returns (`hk.transform` output).
Params = hk.Params
NetworkFn = Callable[..., Any]  # Function `hk.transform` can be called on.
class QNetworkOutputs(typing.NamedTuple):
  """Outputs of a Q-network (see `double_dqn_atari_network`)."""
  q_values: jnp.ndarray
class QRNetworkOutputs(typing.NamedTuple):
  """Outputs of a QR-DQN network (see `qr_atari_network`)."""
  # Mean of `q_dist` over its quantile axis.
  q_values: jnp.ndarray
  # Quantile values, reshaped to (-1, num_quantiles, num_actions).
  q_dist: jnp.ndarray
# Number of quantiles generated by `make_quantiles` for the QR-DQN network.
NUM_QUANTILES = 201
def _dqn_default_initializer(
    num_input_units: int) -> hk.initializers.Initializer:
  """Returns DQN's legacy uniform initializer for a layer.

  Historically all weights and biases of the convolutional and linear layers
  of DQN-type agents were drawn independently and uniformly from
  [`-c`, `c`] with `c = 1 / np.sqrt(num_input_units)`. Here `num_input_units`
  is the fan-in of a single output unit: the number of inputs for a linear
  (dense) layer, and `num_input_channels * kernel_width * kernel_height` for
  a convolutional layer.

  Args:
    num_input_units: number of input units to a single output unit of the layer.

  Returns:
    Haiku weight initializer.
  """
  bound = np.sqrt(1 / num_input_units)
  return hk.initializers.RandomUniform(-bound, bound)
def make_quantiles():
  """Returns the `NUM_QUANTILES` bin midpoints `(i + 0.5) / NUM_QUANTILES`."""
  bin_midpoints = jnp.arange(0, NUM_QUANTILES) + 0.5
  return bin_midpoints / float(NUM_QUANTILES)
def conv(
    num_features: int,
    kernel_shape: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]],
    name=None,
) -> NetworkFn:
  """Convolutional layer with DQN's legacy weight initialization scheme.

  Args:
    num_features: Number of output channels of the convolution.
    kernel_shape: Kernel size, either a `(height, width)` pair or a single
      integer denoting a square kernel.
    stride: Stride of the convolution, forwarded to `hk.Conv2D`.
    name: Optional name of the `hk.Conv2D` module.

  Returns:
    Function applying a 'VALID'-padded `hk.Conv2D` layer to its inputs.
  """
  # The annotation allows a bare integer, so the fan-in computation must not
  # assume `kernel_shape` is indexable.
  if isinstance(kernel_shape, int):
    kernel_area = kernel_shape * kernel_shape
  else:
    kernel_area = kernel_shape[0] * kernel_shape[1]

  def net_fn(inputs):
    """Function representing conv layer with DQN's legacy initialization."""
    # Fan-in of one output unit: input channels times kernel area.
    num_input_units = inputs.shape[-1] * kernel_area
    initializer = _dqn_default_initializer(num_input_units)
    layer = hk.Conv2D(
        num_features,
        kernel_shape=kernel_shape,
        stride=stride,
        w_init=initializer,
        b_init=initializer,
        padding='VALID',
        name=name)
    return layer(inputs)

  return net_fn
def linear(num_outputs: int, with_bias=True, name=None) -> NetworkFn:
  """Builds a dense layer that uses DQN's legacy initialization scheme.

  Args:
    num_outputs: Number of output units.
    with_bias: Whether the `hk.Linear` layer has a bias.
    name: Optional name of the `hk.Linear` module.

  Returns:
    Function applying the layer to its inputs.
  """

  def net_fn(inputs):
    """Applies the linear layer; fan-in is the size of the last input axis."""
    init = _dqn_default_initializer(inputs.shape[-1])
    dense = hk.Linear(
        num_outputs, with_bias=with_bias, w_init=init, b_init=init, name=name)
    return dense(inputs)

  return net_fn
def linear_with_shared_bias(num_outputs: int, name=None) -> NetworkFn:
  """Builds a dense layer whose outputs all share one scalar bias.

  Args:
    num_outputs: Number of output units.
    name: Optional name of the bias-free `hk.Linear` module.

  Returns:
    Function applying the layer to its inputs.
  """

  def layer_fn(inputs):
    """Applies a bias-free linear map and adds the single shared bias."""
    init = _dqn_default_initializer(inputs.shape[-1])
    projected = hk.Linear(
        num_outputs, with_bias=False, w_init=init, name=name)(inputs)
    # One bias parameter of shape [1], explicitly broadcast to the output.
    shared_bias = hk.get_parameter('b', [1], inputs.dtype, init=init)
    return projected + jnp.broadcast_to(shared_bias, projected.shape)

  return layer_fn
def dqn_torso() -> NetworkFn:
  """Builds the convolutional torso of a DQN network.

  The torso first rescales inputs from [`0`, `255`] (`uint8`) to [`0`, `1`]
  (`float32`), then applies three ReLU-activated convolutions and flattens
  the result.

  Returns:
    Network function that `haiku.transform` can be called on.
  """

  def net_fn(inputs):
    """Function representing convolutional torso for a DQN Q-network."""

    def to_unit_range(frames):
      return frames.astype(jnp.float32) / 255.

    layers = [
        to_unit_range,
        conv(32, kernel_shape=(8, 8), stride=(4, 4), name='conv1'),
        jax.nn.relu,
        conv(64, kernel_shape=(4, 4), stride=(2, 2), name='conv2'),
        jax.nn.relu,
        conv(64, kernel_shape=(3, 3), stride=(1, 1), name='conv3'),
        jax.nn.relu,
        hk.Flatten(),
    ]
    return hk.Sequential(layers)(inputs)

  return net_fn
def dqn_value_head(num_actions: int, shared_bias: bool = False) -> NetworkFn:
  """Builds a DQN Q-value head: a 512-unit hidden layer plus output layer.

  Args:
    num_actions: Number of outputs of the final layer.
    shared_bias: If true, the output layer uses a single bias shared by all
      outputs instead of one bias per output.

  Returns:
    Function applying the head to its inputs.
  """
  if shared_bias:
    make_output_layer = linear_with_shared_bias
  else:
    make_output_layer = linear

  def net_fn(inputs):
    """Function representing value head for a DQN Q-network."""
    head = hk.Sequential([
        linear(512, name='linear1'),
        jax.nn.relu,
        make_output_layer(num_actions, name='output'),
    ])
    return head(inputs)

  return net_fn
def qr_atari_network(num_actions: int, quantiles: jnp.ndarray) -> NetworkFn:
  """Builds a QR-DQN network, expects `uint8` input.

  Args:
    num_actions: Number of actions.
    quantiles: Rank-1 array; only its length (the number of quantiles) is
      used here.

  Returns:
    Function mapping inputs to `QRNetworkOutputs`.
  """
  chex.assert_rank(quantiles, 1)
  num_quantiles = len(quantiles)

  def net_fn(inputs):
    """Function representing QR-DQN Q-network."""
    torso_and_head = hk.Sequential([
        dqn_torso(),
        dqn_value_head(num_quantiles * num_actions),
    ])
    q_dist = jnp.reshape(
        torso_and_head(inputs), (-1, num_quantiles, num_actions))
    # Q-values are the mean over the quantile axis; no gradient flows
    # through them.
    q_values = jax.lax.stop_gradient(jnp.mean(q_dist, axis=1))
    return QRNetworkOutputs(q_dist=q_dist, q_values=q_values)

  return net_fn
def double_dqn_atari_network(num_actions: int) -> NetworkFn:
  """Builds a DQN network with a shared output bias, expects `uint8` input.

  Args:
    num_actions: Number of Q-values produced by the head.

  Returns:
    Function mapping inputs to `QNetworkOutputs`.
  """

  def net_fn(inputs):
    """Function representing DQN Q-network with shared bias output layer."""
    torso_and_head = hk.Sequential([
        dqn_torso(),
        dqn_value_head(num_actions, shared_bias=True),
    ])
    q_values = torso_and_head(inputs)
    return QNetworkOutputs(q_values=q_values)

  return net_fn
def make_network(network_type: str, num_actions: int) -> Network:
  """Constructs the Haiku-transformed network of the given type.

  Args:
    network_type: Either 'double_q' or 'qr'.
    num_actions: Number of actions the network outputs values for.

  Returns:
    The transformed network.

  Raises:
    ValueError: If `network_type` is not a known network name.
  """
  if network_type == 'double_q':
    return hk.transform(double_dqn_atari_network(num_actions))
  if network_type == 'qr':
    return hk.transform(qr_atari_network(num_actions, make_quantiles()))
  raise ValueError('Unknown network "{}"'.format(network_type))
+487
View File
@@ -0,0 +1,487 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Components for DQN."""
import abc
import collections
import csv
import os
import timeit
from typing import Any, Iterable, Mapping, Optional, Text, Tuple, Union
import dm_env
import jax
import jax.numpy as jnp
import numpy as np
import rlax
from tandem_dqn import networks
from tandem_dqn import processors
# Type aliases used in the annotations of this module.
Action = int  # Actions are plain Python integers.
Network = networks.Network
NetworkParams = networks.Params
PRNGKey = jnp.ndarray # A size 2 array.
class Agent(abc.ABC):
  """Abstract interface implemented by all agents.

  `run_loop` drives an agent by calling `reset()` at the start of every
  episode and `step()` on every environment timestep.
  """

  @abc.abstractmethod
  def step(self, timestep: dm_env.TimeStep) -> Action:
    """Selects action given timestep and potentially learns."""

  @abc.abstractmethod
  def reset(self) -> None:
    """Resets the agent's episodic state such as frame stack and action repeat.

    This method should be called at the beginning of every episode.
    """

  @abc.abstractmethod
  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves agent state as a dictionary (e.g. for serialization)."""

  @abc.abstractmethod
  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets agent state from a (potentially de-serialized) dictionary."""

  @property
  @abc.abstractmethod
  def statistics(self) -> Mapping[Text, float]:
    """Returns current agent statistics as a dictionary."""
def run_loop(
    agent: Agent,
    environment: dm_env.Environment,
    max_steps_per_episode: int = 0,
    yield_before_reset: bool = False,
) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent,
                    Optional[Action]]]:
  """Runs agent and environment against each other, episode after episode.

  Within an episode, at time `t` the loop has seen `t + 1` environment
  timesteps and `t + 1` agent steps; `t` restarts at `0` for every episode.
  The generator never terminates on its own.

  Args:
    agent: Agent to be run, has methods `step(timestep)` and `reset()`.
    environment: Environment to run, has methods `step(action)` and `reset()`.
    max_steps_per_episode: If positive, when time t reaches this value within an
      episode, the episode is truncated.
    yield_before_reset: Whether to additionally yield `(environment, None,
      agent, None)` before the agent and environment is reset at the start of
      each episode.

  Yields:
    Tuple `(environment, timestep_t, agent, a_t)` where
    `a_t = agent.step(timestep_t)`. For the final timestep of an episode the
    agent is still stepped, but `None` is yielded as the action.
  """
  while True:  # One iteration per episode.
    if yield_before_reset:
      yield environment, None, agent, None

    agent.reset()
    timestep_t = environment.reset()  # timestep_0.
    num_steps = 0

    while True:  # One iteration per step of the current episode.
      a_t = agent.step(timestep_t)
      yield environment, timestep_t, agent, a_t

      # Advance the environment with the action that was just selected.
      num_steps += 1
      timestep_t = environment.step(a_t)

      if 0 < max_steps_per_episode <= num_steps:
        assert num_steps == max_steps_per_episode
        # Truncate the episode by marking this timestep as the last one.
        timestep_t = timestep_t._replace(step_type=dm_env.StepType.LAST)

      if not timestep_t.last():
        continue

      agent.step(timestep_t)  # Extra agent step, action ignored.
      yield environment, timestep_t, agent, None
      break
def generate_statistics(
    trackers: Iterable[Any],
    timestep_action_sequence: Iterable[Tuple[dm_env.Environment,
                                             Optional[dm_env.TimeStep], Agent,
                                             Optional[Action]]]
) -> Mapping[Text, Any]:
  """Generates statistics from a sequence of timestep and actions.

  Args:
    trackers: Objects with `reset()`, `step(environment, timestep_t, agent,
      a_t)` and `get()` methods. Any iterable is accepted, including one-shot
      iterators.
    timestep_action_sequence: Tuples `(environment, timestep_t, agent, a_t)`,
      e.g. as yielded by `run_loop`; consumed completely.

  Returns:
    A single dictionary merging the `get()` results of all trackers. If two
    trackers report the same key, the tracker that comes first wins.
  """
  # `trackers` is traversed several times below; materialize it so that a
  # one-shot iterator is not silently exhausted after the first pass.
  trackers = tuple(trackers)

  # Only reset at the start, not between episodes.
  for tracker in trackers:
    tracker.reset()

  for environment, timestep_t, agent, a_t in timestep_action_sequence:
    for tracker in trackers:
      tracker.step(environment, timestep_t, agent, a_t)

  # Merge all statistics dictionaries into one.
  statistics_dicts = (tracker.get() for tracker in trackers)
  return dict(collections.ChainMap(*statistics_dicts))
class EpisodeTracker:
  """Tracks episode returns and step counts over a sequence of timesteps."""

  def __init__(self):
    # All counters are initialized by `reset()`, which must be called first.
    self._num_steps_since_reset = None
    self._num_steps_over_episodes = None
    self._episode_returns = None
    self._current_episode_rewards = None
    self._current_episode_step = None

  def step(
      self,
      environment: Optional[dm_env.Environment],
      timestep_t: dm_env.TimeStep,
      agent: Optional[Agent],
      a_t: Optional[Action],
  ) -> None:
    """Accumulates statistics from timestep.

    Raises:
      ValueError: If a FIRST timestep arrives while the current episode still
        has recorded rewards or a nonzero step count.
    """
    del (environment, agent, a_t)

    if not timestep_t.first():
      # First reward is invalid, all other rewards are appended.
      self._current_episode_rewards.append(timestep_t.reward)
    else:
      if self._current_episode_rewards:
        raise ValueError('Current episode reward list should be empty.')
      if self._current_episode_step != 0:
        raise ValueError('Current episode step should be zero.')

    self._num_steps_since_reset += 1
    self._current_episode_step += 1

    if not timestep_t.last():
      return

    # Episode finished: record its return and start a fresh episode.
    self._episode_returns.append(sum(self._current_episode_rewards))
    self._current_episode_rewards = []
    self._num_steps_over_episodes += self._current_episode_step
    self._current_episode_step = 0

  def reset(self) -> None:
    """Resets all gathered statistics, not to be called between episodes."""
    self._episode_returns = []
    self._current_episode_rewards = []
    self._num_steps_since_reset = 0
    self._num_steps_over_episodes = 0
    self._current_episode_step = 0

  def get(self) -> Mapping[Text, Union[int, float, None]]:
    """Aggregates statistics and returns as a dictionary.

    `episode_return` equals `mean_episode_return` (mean over complete episodes
    only) once at least one episode has finished; before that it equals
    `current_episode_return`. If no steps have been taken at all, it is `NaN`.

    Returns:
      A dictionary of aggregated statistics.
    """
    completed_returns = self._episode_returns
    if completed_returns or self._num_steps_since_reset > 0:
      current_episode_return = sum(self._current_episode_rewards)
    else:
      current_episode_return = np.nan

    if completed_returns:
      mean_episode_return = np.array(completed_returns).mean()
      episode_return = mean_episode_return
    else:
      mean_episode_return = np.nan
      episode_return = current_episode_return

    return {
        'mean_episode_return': mean_episode_return,
        'current_episode_return': current_episode_return,
        'episode_return': episode_return,
        'num_episodes': len(self._episode_returns),
        'num_steps_over_episodes': self._num_steps_over_episodes,
        'current_episode_step': self._current_episode_step,
        'num_steps_since_reset': self._num_steps_since_reset,
    }
class StepRateTracker:
  """Tracks step rate, number of steps taken and duration since last reset."""

  def __init__(self):
    # Both fields are initialized by `reset()`, which must be called first.
    self._num_steps_since_reset = None
    self._start = None

  def step(
      self,
      environment: Optional[dm_env.Environment],
      timestep_t: Optional[dm_env.TimeStep],
      agent: Optional[Agent],
      a_t: Optional[Action],
  ) -> None:
    """Counts one step; all arguments are ignored."""
    del (environment, timestep_t, agent, a_t)
    self._num_steps_since_reset += 1

  def reset(self) -> None:
    """Zeroes the step counter and restarts the timer."""
    self._num_steps_since_reset = 0
    self._start = timeit.default_timer()

  def get(self) -> Mapping[Text, float]:
    """Returns step rate, step count and elapsed time since `reset()`.

    The step rate is `NaN` while no steps have been counted.
    """
    elapsed = timeit.default_timer() - self._start
    num_steps = self._num_steps_since_reset
    step_rate = num_steps / elapsed if num_steps > 0 else np.nan
    return {
        'step_rate': step_rate,
        'num_steps': self._num_steps_since_reset,
        'duration': elapsed,
    }
class UnbiasedExponentialWeightedAverageAgentTracker:
  """'Unbiased Constant-Step-Size Trick' from the Sutton and Barto RL book.

  Maintains an exponentially weighted average of `agent.statistics` whose
  effective step size is `step_size / trace`, so that the very first sample
  fully replaces the initial statistics.
  """

  def __init__(self, step_size: float, initial_agent: Agent):
    self._initial_statistics = dict(initial_agent.statistics)
    self._step_size = step_size
    self.trace = 0.
    self._statistics = dict(self._initial_statistics)

  def step(
      self,
      environment: Optional[dm_env.Environment],
      timestep_t: Optional[dm_env.TimeStep],
      agent: Agent,
      a_t: Optional[Action],
  ) -> None:
    """Accumulates agent statistics."""
    del (environment, timestep_t, a_t)

    step_size = self._step_size
    self.trace = (1 - step_size) * self.trace + step_size
    weight = step_size / self.trace
    assert 0 <= weight <= 1

    if weight == 1:
      # Since the self._initial_statistics is likely to be NaN and
      # 0 * NaN == NaN just replace self._statistics on the first step.
      self._statistics = dict(agent.statistics)
      return

    def blend(old, new):
      return (1 - weight) * old + weight * new

    self._statistics = jax.tree_multimap(
        blend, self._statistics, agent.statistics)

  def reset(self) -> None:
    """Resets statistics and internal state."""
    self.trace = 0.
    # get() may be called before step() so ensure statistics are initialized.
    self._statistics = dict(self._initial_statistics)

  def get(self) -> Mapping[Text, float]:
    """Returns current accumulated statistics."""
    return self._statistics
def make_default_trackers(initial_agent: Agent):
  """Returns the default trackers: episode, step-rate and agent statistics."""
  agent_statistics_tracker = UnbiasedExponentialWeightedAverageAgentTracker(
      step_size=1e-3, initial_agent=initial_agent)
  return [EpisodeTracker(), StepRateTracker(), agent_statistics_tracker]
class EpsilonGreedyActor(Agent):
  """Agent acting epsilon-greedily with externally supplied Q-network params.

  Network parameters are set on the actor. The actor can be serialized,
  ensuring determinism of execution (e.g. when checkpointing).
  """

  def __init__(
      self,
      preprocessor: processors.Processor,
      network: Network,
      exploration_epsilon: float,
      rng_key: PRNGKey,
  ):
    self._preprocessor = preprocessor
    self._rng_key = rng_key
    self._action = None
    self.network_params = None  # Nest of arrays (haiku.Params), set externally.

    def select_action(rng_key, network_params, s_t):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      next_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      # Add a leading batch axis for the network, then strip it again.
      outputs = network.apply(network_params, apply_key, s_t[None, ...])
      q_t = outputs.q_values[0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      return next_key, a_t

    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> Action:
    """Selects action given a timestep.

    When the preprocessor returns `None`, the previously selected action is
    repeated.
    """
    processed = self._preprocessor(timestep)
    if processed is not None:
      self._rng_key, a_t = self._select_action(
          self._rng_key, self.network_params, processed.observation)
      self._action = Action(jax.device_get(a_t))
    return self._action

  def reset(self) -> None:
    """Resets the agent's episodic state such as frame stack and action repeat.

    This method should be called at the beginning of every episode.
    """
    processors.reset(self._preprocessor)
    self._action = None

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves agent state as a dictionary (e.g. for serialization)."""
    # State contains network params to make agent easy to run from a checkpoint.
    return dict(rng_key=self._rng_key, network_params=self.network_params)

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets agent state from a (potentially de-serialized) dictionary."""
    self._rng_key = state['rng_key']
    self.network_params = state['network_params']

  @property
  def statistics(self) -> Mapping[Text, float]:
    """This actor reports no statistics."""
    return {}
class LinearSchedule:
  """Linear schedule, used for exploration epsilon in DQN agents.

  The value is `begin_value` up to and including `begin_t`, changes linearly
  over the following `decay_steps` steps, and stays at `end_value` afterwards.
  """

  def __init__(self,
               begin_value,
               end_value,
               begin_t,
               end_t=None,
               decay_steps=None):
    """Initializes the schedule.

    Args:
      begin_value: Value returned up to and including `begin_t`.
      end_value: Value returned once the transition has completed.
      begin_t: Time at which the transition starts.
      end_t: Time at which the transition ends.
      decay_steps: Length of the transition, alternative to `end_t`.

    Raises:
      ValueError: If not exactly one of `end_t`, `decay_steps` is given.
    """
    if (end_t is None) == (decay_steps is None):
      raise ValueError('Exactly one of end_t, decay_steps must be provided.')
    self._decay_steps = decay_steps if end_t is None else end_t - begin_t
    self._begin_t = begin_t
    self._begin_value = begin_value
    self._end_value = end_value

  def __call__(self, t):
    """Implements a linear transition from a begin to an end value."""
    if self._decay_steps == 0:
      # A zero-length transition would divide by zero below. Treat it as the
      # limit of ever shorter transitions: a step right after `begin_t`.
      frac = 1. if t > self._begin_t else 0.
    else:
      elapsed = min(max(t - self._begin_t, 0), self._decay_steps)
      frac = elapsed / self._decay_steps
    return (1 - frac) * self._begin_value + frac * self._end_value
class NullWriter:
  """A placeholder logging object that does nothing."""

  def write(self, *args, **kwargs) -> None:
    """Accepts and discards any arguments."""

  def close(self) -> None:
    """Does nothing."""
class CsvWriter:
  """A logging object writing to a CSV file.

  Each `write()` takes an `OrderedDict`, creating one column in the CSV file
  for each dictionary key on the first call. Successive calls to `write()`
  must contain the same dictionary keys.
  """

  def __init__(self, fname: Text):
    """Initializes a `CsvWriter`.

    Args:
      fname: File name (path) for file to be written to. Missing parent
        directories are created.
    """
    dirname = os.path.dirname(fname)
    # A bare file name has an empty dirname, which `os.makedirs` rejects.
    if dirname:
      os.makedirs(dirname, exist_ok=True)
    self._fname = fname
    self._header_written = False
    self._fieldnames = None

  def write(self, values: collections.OrderedDict) -> None:
    """Appends given values as new row to CSV file.

    Raises:
      ValueError: If `values` contains keys that were not present in the first
        call (raised by `csv.DictWriter`).
    """
    if self._fieldnames is None:
      # Snapshot the keys as a list: a `dict_keys` view would track later
      # mutations of the caller's dict and cannot be pickled with the state.
      self._fieldnames = list(values.keys())
    # Open a file in 'append' mode, so we can continue logging safely to the
    # same file after e.g. restarting from a checkpoint. `newline=''` is what
    # the `csv` module requires to control line endings itself.
    with open(self._fname, 'a', newline='') as file:
      # Always use same fieldnames to create writer, this way a consistency
      # check is performed automatically on each write.
      writer = csv.DictWriter(file, fieldnames=self._fieldnames)
      # Write a header if this is the very first write.
      if not self._header_written:
        writer.writeheader()
        self._header_written = True
      writer.writerow(values)

  def close(self) -> None:
    """Closes the `CsvWriter`; a no-op since no file is kept open."""
    pass

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves `CsvWriter` state as a `dict` (e.g. for serialization)."""
    return {
        'header_written': self._header_written,
        'fieldnames': self._fieldnames
    }

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets `CsvWriter` state from a (potentially de-serialized) dictionary."""
    self._header_written = state['header_written']
    self._fieldnames = state['fieldnames']
class NullCheckpoint:
  """A placeholder checkpointing object that does nothing.

  Can be used as a substitute for an actual checkpointing object when
  checkpointing is disabled.
  """

  def __init__(self):
    self.state = AttributeDict()

  def save(self) -> None:
    """Does nothing."""

  def can_be_restored(self) -> bool:
    """Always reports that there is nothing to restore."""
    return False

  def restore(self) -> None:
    """Does nothing."""
class _MissingAttributeKeyError(AttributeError, KeyError):
  """Raised when a missing key is accessed or deleted as an attribute.

  It is an `AttributeError` so that `hasattr`, `getattr` with a default and
  `copy.deepcopy` work as Python expects, and also a `KeyError` so that
  callers which catch `KeyError` keep working.
  """


class AttributeDict(dict):
  """A `dict` that supports getting, setting, deleting keys via attributes."""

  def __getattr__(self, key):
    # Only called when normal attribute lookup fails, so real `dict`
    # attributes and methods are never shadowed by keys.
    try:
      return self[key]
    except KeyError:
      raise _MissingAttributeKeyError(key) from None

  def __setattr__(self, key, value):
    self[key] = value

  def __delattr__(self, key):
    try:
      del self[key]
    except KeyError:
      raise _MissingAttributeKeyError(key) from None
File diff suppressed because it is too large Load Diff
+167
View File
@@ -0,0 +1,167 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Replay components for DQN-type agents."""
import collections
import typing
from typing import Any, Callable, Generic, Iterable, List, Mapping, Optional, Sequence, Text, Tuple, TypeVar
import dm_env
import numpy as np
import snappy
from tandem_dqn import parts
# (compressed bytes, shape, dtype), as returned by `compress_array`.
CompressedArray = Tuple[bytes, Tuple, np.dtype]
# Generic replay structure: Any flat named tuple.
ReplayStructure = TypeVar('ReplayStructure', bound=Tuple[Any, ...])
class Transition(typing.NamedTuple):
  """One-step transition, as built by `TransitionAccumulatorWithMCReturn`."""
  # Observation of the previous timestep.
  s_tm1: Optional[np.ndarray]
  # Action selected at the previous timestep.
  a_tm1: Optional[parts.Action]
  # Reward of the current timestep.
  r_t: Optional[float]
  # Discount of the current timestep.
  discount_t: Optional[float]
  # Observation of the current timestep.
  s_t: Optional[np.ndarray]
  # Action selected at the current timestep.
  a_t: Optional[parts.Action] = None
  # Discounted return from this transition to the end of the episode,
  # accumulated backwards as `discount_t * mc_return + r_t`.
  mc_return_tm1: Optional[float] = None
class TransitionReplay(Generic[ReplayStructure]):
  """Uniform replay, with circular buffer storage for flat named tuples.

  Items pass through `encoder` when stored and through `decoder` when read;
  both default to the identity.
  """

  def __init__(self,
               capacity: int,
               structure: ReplayStructure,
               random_state: np.random.RandomState,
               encoder: Optional[Callable[[ReplayStructure], Any]] = None,
               decoder: Optional[Callable[[Any], ReplayStructure]] = None):
    identity = lambda s: s
    self._capacity = capacity
    self._structure = structure
    self._random_state = random_state
    self._encoder = encoder or identity
    self._decoder = decoder or identity
    self._storage = [None] * capacity
    self._num_added = 0

  def add(self, item: ReplayStructure) -> None:
    """Adds single item to replay, overwriting the oldest slot when full."""
    slot = self._num_added % self._capacity
    self._storage[slot] = self._encoder(item)
    self._num_added += 1

  def get(self, indices: Sequence[int]) -> List[ReplayStructure]:
    """Retrieves (decoded) items by indices."""
    decode = self._decoder
    storage = self._storage
    return [decode(storage[index]) for index in indices]

  def sample(self, size: int) -> ReplayStructure:
    """Samples batch of items from replay uniformly, with replacement.

    Returns:
      One instance of the replay structure whose fields are the sampled
      items' fields stacked along a new leading axis.
    """
    drawn = self.get(
        self._random_state.choice(self.size, size=size, replace=True))
    # Transpose the list of items into one stacked array per field.
    stacked_fields = [
        np.stack(field_values, axis=0) for field_values in zip(*drawn)
    ]
    return type(self._structure)(*stacked_fields)  # pytype: disable=not-callable

  @property
  def size(self) -> int:
    """Number of items currently contained in replay."""
    return min(self._num_added, self._capacity)

  @property
  def capacity(self) -> int:
    """Total capacity of replay (max number of items stored at any one time)."""
    return self._capacity

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves replay state as a dictionary (e.g. for serialization)."""
    return dict(storage=self._storage, num_added=self._num_added)

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets replay state from a (potentially de-serialized) dictionary."""
    self._storage = state['storage']
    self._num_added = state['num_added']
class TransitionAccumulatorWithMCReturn:
  """Accumulates timesteps to transitions with MC returns.

  Transitions are buffered for a whole episode and only yielded, annotated
  with their Monte-Carlo returns, once the LAST timestep has been seen.
  """

  def __init__(self):
    self._transitions = collections.deque()
    self.reset()

  def step(self, timestep_t: dm_env.TimeStep,
           a_t: parts.Action) -> Iterable[Transition]:
    """Accumulates timestep and resulting action, maybe yields transitions.

    This is a generator: nothing is accumulated until the result is iterated.

    Raises:
      ValueError: If the first timestep after a reset is not a FIRST timestep.
    """
    if timestep_t.first():
      self.reset()

    # There are no transitions on the first timestep.
    if self._timestep_tm1 is None:
      assert self._a_tm1 is None
      if not timestep_t.first():
        raise ValueError('Expected FIRST timestep, got %s.' % str(timestep_t))
      self._timestep_tm1, self._a_tm1 = timestep_t, a_t
      return  # Empty iterable.

    new_transition = Transition(
        s_tm1=self._timestep_tm1.observation,
        a_tm1=self._a_tm1,
        r_t=timestep_t.reward,
        discount_t=timestep_t.discount,
        s_t=timestep_t.observation,
        a_t=a_t,
        mc_return_tm1=None,
    )
    self._transitions.append(new_transition)
    self._timestep_tm1, self._a_tm1 = timestep_t, a_t

    # Wait for episode end before yielding anything.
    if not timestep_t.last():
      return

    # Annotate all episode transitions with their MC returns, walking the
    # episode backwards so each return builds on its successor's.
    mc_return = 0
    annotated = collections.deque()
    while self._transitions:
      transition = self._transitions.pop()
      mc_return = transition.discount_t * mc_return + transition.r_t
      annotated.appendleft(transition._replace(mc_return_tm1=mc_return))

    # Yield in chronological order.
    yield from annotated

  def reset(self) -> None:
    """Resets the accumulator. Following timestep is expected to be FIRST."""
    self._transitions.clear()
    self._timestep_tm1 = None
    self._a_tm1 = None
def compress_array(array: np.ndarray) -> CompressedArray:
  """Snappy-compresses a numpy array, keeping its shape and dtype alongside."""
  compressed_bytes = snappy.compress(array)
  return compressed_bytes, array.shape, array.dtype
def uncompress_array(compressed: CompressedArray) -> np.ndarray:
  """Rebuilds a numpy array from snappy-compressed bytes, shape and dtype."""
  payload, shape, dtype = compressed
  flat = np.frombuffer(snappy.uncompress(payload), dtype=dtype)
  return flat.reshape(shape)
+38
View File
@@ -0,0 +1,38 @@
# DeepMind libraries.
chex==0.0.2
dm-env==1.2
dm-haiku @ git+git://github.com/deepmind/dm-haiku.git@db991d56563221d5a06be5e7228155e53d01aba9
dm-tree==0.1.5
optax==0.0.1
rlax @ git+git://github.com/deepmind/rlax.git@870cba1ea8ad36725f4f3a790846298657b6fd4b
# JAX libraries with GPU support.
jax==0.1.72
jaxlib==0.1.50
# ALE and Gym libraries.
atari-py==0.2.6
gym==0.13.1
# Base dependencies.
absl-py==0.9.0
numpy==1.18.0
Pillow==7.1.2
python-snappy==0.6
# Transitive dependencies.
asn1crypto==0.24.0
cloudpickle==1.2.2
cryptography==2.1.4
dataclasses==0.7
future==0.18.2
idna==2.6
keyring==10.6.0
keyrings.alt==3.0
opencv-python==4.2.0.34
opt-einsum==3.2.1
pycrypto==2.6.1
pyxdg==0.25
SecretStorage==2.3.1
six==1.15.0
toolz==0.11.1
+46
View File
@@ -0,0 +1,46 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script installs tandem_dqn in a clean virtualenv and runs an
# example training loop. It is designed to be run from the parent directory,
# e.g.:
#
#   git clone git@github.com:deepmind/deepmind-research.git
#   cd deepmind-research
#   tandem_dqn/run.sh

# Fail on errors, display commands.
set -e
set -x

TMP_DIR="$(mktemp -d)"

# Clean up the temporary virtualenv on exit, including when `set -e` aborts
# the script after a failed step.
trap 'rm -rf -- "${TMP_DIR}"' EXIT

# Set up environment.
virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"

# Install deps.
pip install --upgrade -r tandem_dqn/requirements.txt

# Run small-scale example for a few thousand steps.
python -m tandem_dqn.run_tandem \
  --environment_name=space_invaders \
  --replay_capacity=10000 --target_network_update_period=40 \
  --num_iterations=2 --num_train_frames=5000 --num_eval_frames=500 \
  --jax_platform_name=cpu

echo "Test run finished."
+356
View File
@@ -0,0 +1,356 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Tandem DQN agent implemented in JAX, training on Atari."""
import collections
import itertools
import sys
import typing
from absl import app
from absl import flags
from absl import logging
import dm_env
import jax
from jax.config import config
import numpy as np
import optax
from tandem_dqn import agent as agent_lib
from tandem_dqn import atari_data
from tandem_dqn import gym_atari
from tandem_dqn import losses
from tandem_dqn import networks
from tandem_dqn import parts
from tandem_dqn import processors
from tandem_dqn import replay as replay_lib
# Command-line flags configuring the Tandem DQN run. All help strings are
# intentionally left as they are in the original (empty).
# Relevant flag values are expressed in terms of environment frames.
FLAGS = flags.FLAGS
# Environment selection and observation preprocessing. The height/width pair
# is the resize shape handed to the Atari preprocessor in `main`.
flags.DEFINE_string('environment_name', 'pong', '')
flags.DEFINE_boolean('use_sticky_actions', False, '')
flags.DEFINE_integer('environment_height', 84, '')
flags.DEFINE_integer('environment_width', 84, '')
# Replay buffer. When `compress_state` is set, `main` installs an
# encoder/decoder pair that compresses the `s_tm1` and `s_t` fields of each
# stored transition.
flags.DEFINE_integer('replay_capacity', int(1e6), '')
flags.DEFINE_bool('compress_state', True, '')
# Passed to the agent, and also used (together with `replay_capacity` and
# `num_action_repeats`) to compute when the exploration schedule begins.
flags.DEFINE_float('min_replay_capacity_fraction', 0.05, '')
flags.DEFINE_integer('batch_size', 32, '')
flags.DEFINE_integer('max_frames_per_episode', 108000, '') # 30 mins.
flags.DEFINE_integer('num_action_repeats', 4, '')
flags.DEFINE_integer('num_stacked_frames', 4, '')
# Training exploration follows a linear schedule between the begin and end
# values; the decay length is this fraction of the total number of training
# frames (`num_iterations * num_train_frames`).
flags.DEFINE_float('exploration_epsilon_begin_value', 1., '')
flags.DEFINE_float('exploration_epsilon_end_value', 0.01, '')
flags.DEFINE_float('exploration_epsilon_decay_frame_fraction', 0.02, '')
# Fixed epsilon used by both the active and the passive evaluation actors.
flags.DEFINE_float('eval_exploration_epsilon', 0.01, '')
flags.DEFINE_integer('target_network_update_period', int(1.2e5), '')
flags.DEFINE_float('additional_discount', 0.99, '')
flags.DEFINE_float('max_abs_reward', 1., '')
# Seeds the NumPy random state from which the JAX key and the environment
# seeds are drawn in `main`.
flags.DEFINE_integer('seed', 1, '') # GPU may introduce nondeterminism.
# `main` runs iterations 0..num_iterations inclusive; iteration 0 trains on
# zero frames and only evaluates.
flags.DEFINE_integer('num_iterations', 200, '')
flags.DEFINE_integer('num_train_frames', int(1e6), '') # Per iteration.
flags.DEFINE_integer('num_eval_frames', int(5e5), '') # Per iteration.
flags.DEFINE_integer('learn_period', 16, '')
# An empty path disables CSV output (a null writer is used instead).
flags.DEFINE_string('results_csv_path', '/tmp/results.csv', '')
# Tandem-specific parameters.
# Using fixed configs for optimizers:
# RMSProp: lr = 0.00025, eps=0.01 / (32 ** 2)
# ADAM: lr = 0.00005, eps=0.01 / 32
_OPTIMIZERS = ['rmsprop', 'adam']
flags.DEFINE_enum('optimizer_active', 'rmsprop', _OPTIMIZERS, '')
flags.DEFINE_enum('optimizer_passive', 'rmsprop', _OPTIMIZERS, '')
# Network and loss choices are independent for the active and passive agent,
# but `main` rejects combinations where only one of network/loss is a 'qr'
# variant.
_NETWORKS = ['double_q', 'qr']
flags.DEFINE_enum('network_active', 'double_q', _NETWORKS, '')
flags.DEFINE_enum('network_passive', 'double_q', _NETWORKS, '')
_LOSSES = ['double_q', 'double_q_v', 'double_q_p', 'double_q_pv', 'qr',
'q_regression']
flags.DEFINE_enum('loss_active', 'double_q', _LOSSES, '')
flags.DEFINE_enum('loss_passive', 'double_q', _LOSSES, '')
# Number of leading layers (in the order conv1, conv2, conv3, linear1) to tie;
# `main` requires a value in [0, 4] and 'double_q' networks when it is > 0.
flags.DEFINE_integer('tied_layers', 0, '')
# Local alias for the (active, passive) pair container from the agent module.
TandemTuple = agent_lib.TandemTuple
def make_optimizer(optimizer_type):
  """Builds the optax optimizer selected by name.

  Each optimizer type comes with a fixed hyperparameter configuration; only
  the type itself is configurable.

  Args:
    optimizer_type: Either 'rmsprop' or 'adam'.

  Returns:
    The optax optimizer for the requested type.

  Raises:
    ValueError: If `optimizer_type` is not one of the supported names.
  """
  if optimizer_type == 'rmsprop':
    # Centered RMSProp: lr = 0.00025, eps = 0.01 / (32 ** 2).
    return optax.rmsprop(
        learning_rate=0.00025,
        decay=0.95,
        eps=0.01 / (32**2),
        centered=True)
  if optimizer_type == 'adam':
    # Adam: lr = 0.00005, eps = 0.01 / 32.
    return optax.adam(learning_rate=0.00005, eps=0.01 / 32)
  raise ValueError('Unknown optimizer "{}"'.format(optimizer_type))
def main(argv):
  """Trains Tandem DQN agent on Atari.

  Builds the environment, the active/passive networks, losses and optimizers,
  the replay buffer and the training/evaluation agents from the command-line
  flags. It then runs iterations 0..`num_iterations` (inclusive). Each
  iteration trains the tandem agent (zero frames in iteration 0), evaluates
  the active and the passive network separately, and logs and writes the
  combined statistics.

  Args:
    argv: Command-line arguments forwarded by `app.run`; unused.

  Raises:
    ValueError: If only one of network/loss is a QR variant for the active or
      passive agent, if `tied_layers` is outside [0, 4], or if `tied_layers`
      is positive while either network is not 'double_q'.
  """
  del argv
  logging.info('Tandem DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  # An empty CSV path disables result writing.
  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    env = gym_atari.GymAtari(
        FLAGS.environment_name,
        sticky_actions=FLAGS.use_sticky_actions,
        seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values

  # Check: qr network and qr losses can only be used together.
  if ('qr' in FLAGS.network_active) != ('qr' in FLAGS.loss_active):
    raise ValueError('Active loss/net must either both use QR, or neither.')
  if ('qr' in FLAGS.network_passive) != ('qr' in FLAGS.loss_passive):
    raise ValueError('Passive loss/net must either both use QR, or neither.')

  network = TandemTuple(
      active=networks.make_network(FLAGS.network_active, num_actions),
      passive=networks.make_network(FLAGS.network_passive, num_actions),
  )
  loss = TandemTuple(
      active=losses.make_loss_fn(FLAGS.loss_active, active=True),
      passive=losses.make_loss_fn(FLAGS.loss_passive, active=False),
  )

  # Tied layers.
  # Validate with an explicit exception rather than `assert`, which is
  # stripped under `python -O`; this also matches the other flag checks.
  if not 0 <= FLAGS.tied_layers <= 4:
    raise ValueError(
        'tied_layers must be between 0 and 4, got {}.'.format(
            FLAGS.tied_layers))
  if FLAGS.tied_layers > 0 and (FLAGS.network_passive != 'double_q'
                                or FLAGS.network_active != 'double_q'):
    raise ValueError('Tied layers > 0 is only supported for double_q networks.')
  layers = [
      'sequential/sequential/conv1',
      'sequential/sequential/conv2',
      'sequential/sequential/conv3',
      'sequential/sequential_1/linear1'
  ]
  # The first `tied_layers` entries of the list above are tied.
  tied_layers = set(layers[:FLAGS.tied_layers])

  def preprocessor_builder():
    """Creates a fresh Atari preprocessor configured from flags."""
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (FLAGS.environment_height,
                                        FLAGS.environment_width,
                                        FLAGS.num_stacked_frames)

  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  if FLAGS.compress_state:

    def encoder(transition):
      """Compresses the state fields of a transition for storage."""
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      """Restores the state fields compressed by `encoder`."""
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
      a_t=None,
      mc_return_tm1=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure,
                                       random_state, encoder, decoder)

  optimizer = TandemTuple(
      active=make_optimizer(FLAGS.optimizer_active),
      passive=make_optimizer(FLAGS.optimizer_passive),
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent_lib.TandemDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      loss=loss,
      transition_accumulator=replay_lib.TransitionAccumulatorWithMCReturn(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      tied_layers=tied_layers,
      rng_key=train_rng_key,
  )
  # NOTE(review): both evaluation actors are given the same `eval_rng_key`;
  # confirm that sharing the key between them is intended.
  eval_agent_active = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network.active,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key)
  eval_agent_passive = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network.passive,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key)

  # Set up checkpointing.
  checkpoint = parts.NullCheckpoint()

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent_active = eval_agent_active
  state.eval_agent_passive = eval_agent_passive
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  # Run single iteration of training or evaluation.
  def run_iteration(agent, env, num_frames):
    """Runs `agent` on `env` for `num_frames` steps and returns statistics."""
    seq = parts.run_loop(agent, env, FLAGS.max_frames_per_episode)
    seq_truncated = itertools.islice(seq, num_frames)
    trackers = parts.make_default_trackers(agent)
    return parts.generate_statistics(trackers, seq_truncated)

  def eval_log_output(eval_stats, suffix):
    """Builds (name, value, format) log entries for one evaluated agent."""
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    # The capped score is limited to at most 1.
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    return [
        ('eval_episode_return_' + suffix,
         eval_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes_' + suffix,
         eval_stats['num_episodes'], '%3d'),
        ('eval_frame_rate_' + suffix,
         eval_stats['step_rate'], '%4.0f'),
        ('normalized_return_' + suffix,
         human_normalized_score, '%.3f'),
        ('capped_normalized_return_' + suffix,
         capped_human_normalized_score, '%.3f'),
        ('human_gap_' + suffix,
         1. - capped_human_normalized_score, '%.3f'),
    ]

  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    # Set agent to train active and passive nets on each learning step.
    train_agent.set_training_mode('active_passive')

    logging.info('Training iteration %d.', state.iteration)
    # Iteration 0 performs no training, so it only evaluates the initial nets.
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_stats = run_iteration(train_agent, env, num_train_frames)

    logging.info('Evaluation iteration %d - active agent.', state.iteration)
    eval_agent_active.network_params = train_agent.online_params.active
    eval_stats_active = run_iteration(eval_agent_active, env,
                                      FLAGS.num_eval_frames)

    logging.info('Evaluation iteration %d - passive agent.', state.iteration)
    eval_agent_passive.network_params = train_agent.online_params.passive
    eval_stats_passive = run_iteration(eval_agent_passive, env,
                                       FLAGS.num_eval_frames)

    # Logging and checkpointing.
    agent_logs = [
        'loss_active',
        'loss_passive',
        'frac_diff_argmax',
        'mc_error_active',
        'mc_error_passive',
        'mc_error_abs_active',
        'mc_error_abs_passive',
    ]
    log_output = (
        eval_log_output(eval_stats_active, 'active') +
        eval_log_output(eval_stats_passive, 'passive') +
        [('iteration', state.iteration, '%3d'),
         ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
         ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
         ('train_num_episodes', train_stats['num_episodes'], '%3d'),
         ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ] +
        [(k, train_stats[k], '% 2.2f') for k in agent_logs]
    )
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
# Script entry point: set JAX config defaults, then hand control to absl.
if __name__ == '__main__':
  config.update('jax_platform_name', 'gpu') # Default to GPU.
  # Treat implicit rank promotion in jax.numpy as an error instead of
  # silently broadcasting arrays of different rank.
  config.update('jax_numpy_rank_promotion', 'raise')
  # Presumably exposes the JAX config options as absl flags, so the defaults
  # above can be overridden on the command line (e.g. the example script
  # passes --jax_platform_name=cpu) -- confirm against the JAX version in use.
  config.config_with_absl()
  # Parses flags and invokes `main`.
  app.run(main)