mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-20 11:22:09 +08:00
Added tandem_dqn to deepmind-research repo
PiperOrigin-RevId: 405653277
This commit is contained in:
committed by
Saran Tunyasuvunakool
parent
f6da52cb38
commit
5f6b11299e
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
|
||||
|
||||
## Projects
|
||||
|
||||
* [The Difficulty of Passive Learning in Deep Reinforcement Learning](tandem_dqn), NeurIPS 2021
|
||||
* [Skilful precipitation nowcasting using deep generative models of radar](nowcasting), Nature 2021
|
||||
* [Compute-Aided Design as Language](cadl)
|
||||
* [Encoders and ensembles for continual learning](continual_learning)
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Tandem DQN
|
||||
|
||||
This repository provides an implementation of the Tandem DQN agent and
|
||||
main experiments presented in the paper
|
||||
"The Difficulty of Passive Learning in Deep Reinforcement Learning"
|
||||
(Georg Ostrovski, Pablo Samuel Castro and Will Dabney, 2021).
|
||||
|
||||
The code is a modified fork of the [DoubleDQN](https://arxiv.org/abs/1509.06461)
|
||||
agent in the [DQN Zoo](https://github.com/deepmind/dqn_zoo) agent collection.
|
||||
|
||||
## Quick start
|
||||
|
||||
This code can be run with a regular CPU setup, but will only be reasonably fast
|
||||
(with the given configuration) if run with a GPU accelerator, in which case
|
||||
you will need an NVIDIA GPU with recent CUDA drivers.
|
||||
For installation, run
|
||||
|
||||
```shell
|
||||
git clone git@github.com:deepmind/deepmind-research.git
|
||||
virtualenv --python=python3.6 "tandem"
|
||||
source tandem/bin/activate
|
||||
cd deepmind-research
|
||||
pip install -r tandem_dqn/requirements.txt
|
||||
```
|
||||
|
||||
The code has only been tested with Python 3.6.14.
|
||||
|
||||
## Running the experiments presented in the paper
|
||||
|
||||
To execute the vanilla Tandem DQN experiment on the Pong environment, run:
|
||||
|
||||
```bash
|
||||
python -m "tandem_dqn.run_tandem" --environment_name=pong --seed=42
|
||||
```
|
||||
|
||||
A number of flags can be specified to customize execution and run several of
|
||||
the presented experimental variations:
|
||||
|
||||
* `--use_sticky_actions`
|
||||
(values: `True`, `False`; default: `False`):
|
||||
Use "sticky actions" variant
|
||||
([Machado et al 2017](https://arxiv.org/abs/1709.06009)) of the Atari environment.
|
||||
|
||||
* `--network_active`, `--network_passive`
|
||||
(values: `double_q`, `qr`; default: `double_q`):
|
||||
Whether to use the DoubleDQN or QR-DQN
|
||||
([Dabney et al 2017](https://arxiv.org/abs/1710.10044)) network architecture
|
||||
for active and passive agents (can be set independently). Note this value
|
||||
needs to be compatible with the chosen loss (see below).
|
||||
|
||||
* `--loss_active`, `--loss_passive`
|
||||
(values: `double_q`, `double_q_v`, `double_q_p`, `double_q_pv`, `qr`,
|
||||
`q_regression`; default: `double_q`):
|
||||
Which loss to use for active and passive agent training. The losses are:
|
||||
* `double_q`: regular Double-Q-Learning loss.
|
||||
* `double_q_v`: Double-Q-Learning with bootstrap target _values_ provided
|
||||
by the respective _other_ agent's target network.
|
||||
* `double_q_p`: Double-Q-Learning with bootstrap target _policy_ (argmax)
|
||||
provided by the respective _other_ agent's online network.
|
||||
* `double_q_pv`: Double-Q-Learning with both bootstrap target values and
|
||||
policy provided by the respective _other_ agent's target and online networks.
|
||||
* `qr`: Quantile Regression Q-Learning loss
|
||||
([Dabney et al 2017](https://arxiv.org/abs/1710.10044)).
|
||||
* `q_regression`: Supervised regression loss towards the respective other
|
||||
agent's online network's output values.
|
||||
|
||||
* `--optimizer_active`, `--optimizer_passive`
|
||||
(values: `adam`, `rmsprop`; default: `rmsprop` for both):
|
||||
Which optimization algorithm to use for the active and passive network training
|
||||
(can be set independently).
|
||||
|
||||
* `--exploration_epsilon_end_value`
|
||||
(values in `[0, 1]`; default: `0.01`):
|
||||
Value of epsilon parameter in active agent's epsilon-greedy policy after an
|
||||
initial decay phase.
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,378 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tandem DQN agent class."""
|
||||
|
||||
import typing
|
||||
from typing import Any, Callable, Mapping, Set, Text
|
||||
|
||||
from absl import logging
|
||||
import dm_env
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import rlax
|
||||
|
||||
from tandem_dqn import losses
|
||||
from tandem_dqn import parts
|
||||
from tandem_dqn import processors
|
||||
from tandem_dqn import replay as replay_lib
|
||||
|
||||
|
||||
class TandemTuple(typing.NamedTuple):
  """Pair of values, one for the active and one for the passive agent."""
  active: Any  # Value associated with the active agent.
  passive: Any  # Value associated with the passive agent.
|
||||
|
||||
|
||||
def tandem_map(fn: Callable[..., Any], *args):
  """Applies `fn` once to the active and once to the passive parts of `args`.

  Args:
    fn: callable invoked first with the `active` attribute of every argument
      and then with the `passive` attribute of every argument, both times in
      argument order.
    *args: objects exposing `active` and `passive` attributes.

  Returns:
    A `TandemTuple` holding the two results of `fn`.
  """
  active_result = fn(*(arg.active for arg in args))
  passive_result = fn(*(arg.passive for arg in args))
  return TandemTuple(active=active_result, passive=passive_result)
|
||||
|
||||
|
||||
def replace_module_params(source, target, modules):
  """Returns `target` with the params of selected modules taken from `source`.

  Args:
    source: params structure providing the replacement values.
    target: params structure whose selected entries get replaced.
    modules: container of module names; membership is tested with `in`.

  Returns:
    The merge of `target` with the selected subset of `source`.
  """
  def is_selected(module_name, unused_name, unused_value):
    return module_name in modules

  selected_params, _ = hk.data_structures.partition(is_selected, source)
  return hk.data_structures.merge(target, selected_params)
|
||||
|
||||
|
||||
class TandemDqn(parts.Agent):
  """Tandem DQN agent.

  Holds an active and a passive network, each with its own optimizer, loss,
  online params and target params. Actions are always selected with the active
  network's online params. `set_training_mode` chooses which of the two
  networks is updated by learning steps and must be called before the first
  learning step.
  """

  def __init__(
      self,
      preprocessor: processors.Processor,
      sample_network_input: jnp.ndarray,
      network: TandemTuple,
      optimizer: TandemTuple,
      loss: TandemTuple,
      transition_accumulator: Any,
      replay: replay_lib.TransitionReplay,
      batch_size: int,
      exploration_epsilon: Callable[[int], float],
      min_replay_capacity_fraction: float,
      learn_period: int,
      target_network_update_period: int,
      tied_layers: Set[str],
      rng_key: parts.PRNGKey,
  ):
    """Initializes the agent.

    Args:
      preprocessor: callable applied to every environment timestep in `step`.
        When it returns `None`, the previous action is repeated.
      sample_network_input: example network input without batch dimension (a
        leading axis is added here), used to initialize the network params.
      network: pair of networks exposing `init` and `apply`.
      optimizer: pair of optimizers exposing `init` and `update`.
      loss: pair of loss functions, each called as `loss(outputs_tm1,
        outputs_t, outputs_target_t, transitions, rng_key)`.
      transition_accumulator: object whose `step(timestep, action)` yields the
        transitions to add to replay, and which supports `reset()`.
      replay: replay buffer that transitions are added to and sampled from.
      batch_size: number of transitions sampled per learning step.
      exploration_epsilon: maps the current frame index to the epsilon used
        for action selection.
      min_replay_capacity_fraction: fraction of `replay.capacity` that must be
        filled before learning starts.
      learn_period: a learning step is run on frames whose index is a multiple
        of this value.
      target_network_update_period: target params are set to the online params
        on frames whose index is a multiple of this value.
      tied_layers: names of modules whose passive params are overwritten with
        the active network's values after every learning step.
      rng_key: random key; split for param initialization, with the remainder
        kept as the agent's own key.
    """
    self._preprocessor = preprocessor
    self._replay = replay
    self._transition_accumulator = transition_accumulator
    self._batch_size = batch_size
    self._exploration_epsilon = exploration_epsilon
    self._min_replay_capacity = min_replay_capacity_fraction * replay.capacity
    self._learn_period = learn_period
    self._target_network_update_period = target_network_update_period

    # Initialize network parameters and optimizer.
    self._rng_key, network_rng_key_active, network_rng_key_passive = (
        jax.random.split(rng_key, 3))
    active_params = network.active.init(
        network_rng_key_active, sample_network_input[None, ...])
    passive_params = network.passive.init(
        network_rng_key_passive, sample_network_input[None, ...])
    self._online_params = TandemTuple(
        active=active_params, passive=passive_params)
    self._target_params = self._online_params
    self._opt_state = tandem_map(
        lambda optim, params: optim.init(params),
        optimizer, self._online_params)

    # Other agent state: last action, frame count, etc.
    self._action = None
    self._frame_t = -1  # Current frame index.

    # Stats. All entries are NaN until the first learning step fills them in.
    stats = [
        'loss_active',
        'loss_passive',
        'frac_diff_argmax',
        'mc_error_active',
        'mc_error_passive',
        'mc_error_abs_active',
        'mc_error_abs_passive',
    ]
    self._statistics = {k: np.nan for k in stats}

    # Define jitted loss, update, and policy functions here instead of as
    # class methods, to emphasize that these are meant to be pure functions
    # and should not access the agent object's state via `self`.

    def network_outputs(rng_key, online_params, target_params, transitions):
      """Compute all potentially needed outputs of active and passive net."""
      _, *apply_keys = jax.random.split(rng_key, 4)
      outputs_tm1 = tandem_map(
          lambda net, param: net.apply(param, apply_keys[0], transitions.s_tm1),
          network, online_params)
      outputs_t = tandem_map(
          lambda net, param: net.apply(param, apply_keys[1], transitions.s_t),
          network, online_params)
      outputs_target_t = tandem_map(
          lambda net, param: net.apply(param, apply_keys[2], transitions.s_t),
          network, target_params)
      return outputs_tm1, outputs_t, outputs_target_t

    # Helper functions to define active and passive losses.
    # Active and passive losses are allowed to depend on all active and passive
    # outputs, but stop-gradient is used to prevent gradients from flowing
    # from active loss to passive network params and vice versa.
    def sg_active(x):
      """Returns `x` with gradients blocked through its active part."""
      return TandemTuple(
          active=jax.lax.stop_gradient(x.active), passive=x.passive)

    def sg_passive(x):
      """Returns `x` with gradients blocked through its passive part."""
      return TandemTuple(
          active=x.active, passive=jax.lax.stop_gradient(x.passive))

    def compute_loss(online_params, target_params, transitions, rng_key):
      """Returns the summed active + passive loss and a dict of log values."""
      rng_key, apply_key = jax.random.split(rng_key)
      outputs_tm1, outputs_t, outputs_target_t = network_outputs(
          apply_key, online_params, target_params, transitions)

      _, loss_key_active, loss_key_passive = jax.random.split(rng_key, 3)
      loss_active = loss.active(
          sg_passive(outputs_tm1), sg_passive(outputs_t), outputs_target_t,
          transitions, loss_key_active)
      loss_passive = loss.passive(
          sg_active(outputs_tm1), sg_active(outputs_t), outputs_target_t,
          transitions, loss_key_passive)

      # Logging stuff.
      # `jax.tree_util.tree_map` is used instead of the `jax.tree_map` alias,
      # which is deprecated and missing from newer JAX releases.
      a_tm1 = transitions.a_tm1
      mc_return_tm1 = transitions.mc_return_tm1
      q_values = TandemTuple(
          active=outputs_tm1.active.q_values,
          passive=outputs_tm1.passive.q_values)
      mc_error = jax.tree_util.tree_map(
          lambda q: losses.batch_mc_learning(q, a_tm1, mc_return_tm1),
          q_values)
      mc_error_abs = jax.tree_util.tree_map(jnp.abs, mc_error)
      q_argmax = jax.tree_util.tree_map(
          lambda q: jnp.argmax(q, axis=-1), q_values)
      argmax_diff = jnp.not_equal(q_argmax.active, q_argmax.passive)

      batch_mean = lambda x: jnp.mean(x, axis=0)
      logs = {
          'loss_active': loss_active,
          'loss_passive': loss_passive
      }
      logs.update(jax.tree_util.tree_map(batch_mean, {
          'frac_diff_argmax': argmax_diff,
          'mc_error_active': mc_error.active,
          'mc_error_passive': mc_error.passive,
          'mc_error_abs_active': mc_error_abs.active,
          'mc_error_abs_passive': mc_error_abs.passive,
      }))
      return loss_active + loss_passive, logs

    def optim_update(optim, online_params, d_loss_d_params, opt_state):
      """Applies one optimizer step; returns new opt state and new params."""
      updates, new_opt_state = optim.update(d_loss_d_params, opt_state)
      new_online_params = optax.apply_updates(online_params, updates)
      return new_opt_state, new_online_params

    def compute_loss_grad(rng_key, online_params, target_params, transitions):
      """Returns the advanced key, logs and gradients w.r.t. online params."""
      rng_key, grad_key = jax.random.split(rng_key)
      (_, logs), d_loss_d_params = jax.value_and_grad(
          compute_loss, has_aux=True)(
              online_params, target_params, transitions, grad_key)
      return rng_key, logs, d_loss_d_params

    def update_active(rng_key, opt_state, online_params, target_params,
                      transitions):
      """Applies learning update for active network only."""
      rng_key, logs, d_loss_d_params = compute_loss_grad(
          rng_key, online_params, target_params, transitions)
      new_opt_state_active, new_online_params_active = optim_update(
          optimizer.active, online_params.active, d_loss_d_params.active,
          opt_state.active)
      new_opt_state = opt_state._replace(
          active=new_opt_state_active)
      new_online_params = online_params._replace(
          active=new_online_params_active)
      return rng_key, new_opt_state, new_online_params, logs

    self._update_active = jax.jit(update_active)

    def update_passive(rng_key, opt_state, online_params, target_params,
                       transitions):
      """Applies learning update for passive network only."""
      rng_key, logs, d_loss_d_params = compute_loss_grad(
          rng_key, online_params, target_params, transitions)
      new_opt_state_passive, new_online_params_passive = optim_update(
          optimizer.passive, online_params.passive, d_loss_d_params.passive,
          opt_state.passive)
      new_opt_state = opt_state._replace(
          passive=new_opt_state_passive)
      new_online_params = online_params._replace(
          passive=new_online_params_passive)
      return rng_key, new_opt_state, new_online_params, logs

    self._update_passive = jax.jit(update_passive)

    def update_active_passive(rng_key, opt_state, online_params,
                              target_params, transitions):
      """Applies learning update for both active & passive networks."""
      rng_key, logs, d_loss_d_params = compute_loss_grad(
          rng_key, online_params, target_params, transitions)

      new_opt_state_active, new_online_params_active = optim_update(
          optimizer.active, online_params.active, d_loss_d_params.active,
          opt_state.active)
      new_opt_state_passive, new_online_params_passive = optim_update(
          optimizer.passive, online_params.passive, d_loss_d_params.passive,
          opt_state.passive)
      new_opt_state = TandemTuple(active=new_opt_state_active,
                                  passive=new_opt_state_passive)
      new_online_params = TandemTuple(active=new_online_params_active,
                                      passive=new_online_params_passive)
      return rng_key, new_opt_state, new_online_params, logs

    self._update_active_passive = jax.jit(update_active_passive)

    self._update = None  # set_training_mode needs to be called to set this.

    def sync_tied_layers(online_params):
      """Set tied layer params of passive to respective values of active."""
      new_online_params_passive = replace_module_params(
          source=online_params.active, target=online_params.passive,
          modules=tied_layers)
      return online_params._replace(passive=new_online_params_passive)

    self._sync_tied_layers = jax.jit(sync_tied_layers)

    def select_action(rng_key, network_params, s_t, exploration_epsilon):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      q_t = network.active.apply(network_params, apply_key,
                                 s_t[None, ...]).q_values[0]
      a_t = rlax.epsilon_greedy().sample(policy_key, q_t, exploration_epsilon)
      return rng_key, a_t

    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> parts.Action:
    """Selects action given timestep and potentially learns."""
    self._frame_t += 1

    timestep = self._preprocessor(timestep)

    if timestep is None:  # Repeat action.
      action = self._action
    else:
      action = self._action = self._act(timestep)

      # NOTE(review): the accumulator is only stepped when the preprocessor
      # produced a timestep, since it is handed that timestep -- confirm this
      # nesting against upstream.
      for transition in self._transition_accumulator.step(timestep, action):
        self._replay.add(transition)

    # No learning or target updates until replay is sufficiently filled.
    if self._replay.size < self._min_replay_capacity:
      return action

    if self._frame_t % self._learn_period == 0:
      self._learn()

    if self._frame_t % self._target_network_update_period == 0:
      self._target_params = self._online_params

    return action

  def reset(self) -> None:
    """Resets the agent's episodic state such as frame stack and action repeat.

    This method should be called at the beginning of every episode.
    """
    self._transition_accumulator.reset()
    processors.reset(self._preprocessor)
    self._action = None

  def _act(self, timestep) -> parts.Action:
    """Selects action given timestep, according to epsilon-greedy policy."""
    s_t = timestep.observation
    # Behavior is always driven by the active network.
    network_params = self._online_params.active
    self._rng_key, a_t = self._select_action(
        self._rng_key, network_params, s_t, self.exploration_epsilon)
    return parts.Action(jax.device_get(a_t))

  def _learn(self) -> None:
    """Samples a batch of transitions from replay and learns from it.

    Requires `set_training_mode` to have been called, as it selects the update
    function used here.
    """
    logging.log_first_n(logging.INFO, 'Begin learning', 1)
    transitions = self._replay.sample(self._batch_size)
    self._rng_key, self._opt_state, self._online_params, logs = self._update(
        self._rng_key,
        self._opt_state,
        self._online_params,
        self._target_params,
        transitions,
    )
    self._online_params = self._sync_tied_layers(self._online_params)
    self._statistics.update(jax.device_get(logs))

  def set_training_mode(self, mode: str):
    """Sets training mode to one of 'active', 'passive', or 'active_passive'.

    Args:
      mode: which network(s) learning steps should update.

    Raises:
      ValueError: if `mode` is not one of the supported values. Previously an
        unknown mode was silently ignored, which left the update function
        unset and only failed later, inside the first learning step.
    """
    update_fns = {
        'active': self._update_active,
        'passive': self._update_passive,
        'active_passive': self._update_active_passive,
    }
    if mode not in update_fns:
      raise ValueError(
          'Unknown training mode %r; expected one of %s.' %
          (mode, sorted(update_fns)))
    self._update = update_fns[mode]

  @property
  def online_params(self) -> TandemTuple:
    """Returns current parameters of Q-network."""
    return self._online_params

  @property
  def statistics(self) -> Mapping[Text, float]:
    """Returns current agent statistics as a dictionary."""
    # Check for device arrays in values as this can be very slow.
    # Newer JAX releases expose the array type as `jax.Array` and no longer
    # provide `jnp.DeviceArray`; older releases only have the latter.
    device_array_type = getattr(jax, 'Array', None)
    if device_array_type is None:
      device_array_type = jnp.DeviceArray
    assert all(
        not isinstance(x, device_array_type)
        for x in self._statistics.values())
    return self._statistics

  @property
  def exploration_epsilon(self) -> float:
    """Returns epsilon value currently used by (eps-greedy) behavior policy."""
    return self._exploration_epsilon(self._frame_t)

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves agent state as a dictionary (e.g. for serialization)."""
    state = {
        'rng_key': self._rng_key,
        'frame_t': self._frame_t,
        'opt_state_active': self._opt_state.active,
        'online_params_active': self._online_params.active,
        'target_params_active': self._target_params.active,
        'opt_state_passive': self._opt_state.passive,
        'online_params_passive': self._online_params.passive,
        'target_params_passive': self._target_params.passive,
        'replay': self._replay.get_state(),
    }
    return state

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets agent state from a (potentially de-serialized) dictionary."""
    self._rng_key = state['rng_key']
    self._frame_t = state['frame_t']
    self._opt_state = TandemTuple(
        active=jax.device_put(state['opt_state_active']),
        passive=jax.device_put(state['opt_state_passive']))
    self._online_params = TandemTuple(
        active=jax.device_put(state['online_params_active']),
        passive=jax.device_put(state['online_params_passive']))
    self._target_params = TandemTuple(
        active=jax.device_put(state['target_params_active']),
        passive=jax.device_put(state['target_params_passive']))
    self._replay.set_state(state['replay'])
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Utilities to compute human-normalized Atari scores.
|
||||
|
||||
The data used in this module is human and random performance data on Atari-57.
|
||||
It comprises evaluation scores (undiscounted returns), each averaged
|
||||
over at least 3 episode runs, on each of the 57 Atari games. Each episode begins
|
||||
with the environment already stepped with a uniform random number (between 1 and
|
||||
30 inclusive) of noop actions.
|
||||
|
||||
The two agents are:
|
||||
* 'random' (agent choosing its actions uniformly randomly on each step)
|
||||
* 'human' (professional human game tester)
|
||||
|
||||
Scores are obtained by averaging returns over the episodes played by each agent,
|
||||
with episode length capped to 108,000 frames (i.e. timeout after 30 minutes).
|
||||
|
||||
The term 'human-normalized' here means a linear per-game transformation of
|
||||
a game score in such a way that 0 corresponds to random performance and 1
|
||||
corresponds to human performance.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
# Maps each game name to a `(random score, human score)` tuple of floats;
# `_RANDOM_COL` and `_HUMAN_COL` below give the position of each entry.
_ATARI_DATA = {
    'alien': (227.8, 7127.7),
    'amidar': (5.8, 1719.5),
    'assault': (222.4, 742.0),
    'asterix': (210.0, 8503.3),
    'asteroids': (719.1, 47388.7),
    'atlantis': (12850.0, 29028.1),
    'bank_heist': (14.2, 753.1),
    'battle_zone': (2360.0, 37187.5),
    'beam_rider': (363.9, 16926.5),
    'berzerk': (123.7, 2630.4),
    'bowling': (23.1, 160.7),
    'boxing': (0.1, 12.1),
    'breakout': (1.7, 30.5),
    'centipede': (2090.9, 12017.0),
    'chopper_command': (811.0, 7387.8),
    'crazy_climber': (10780.5, 35829.4),
    'defender': (2874.5, 18688.9),
    'demon_attack': (152.1, 1971.0),
    'double_dunk': (-18.6, -16.4),
    'enduro': (0.0, 860.5),
    'fishing_derby': (-91.7, -38.7),
    'freeway': (0.0, 29.6),
    'frostbite': (65.2, 4334.7),
    'gopher': (257.6, 2412.5),
    'gravitar': (173.0, 3351.4),
    'hero': (1027.0, 30826.4),
    'ice_hockey': (-11.2, 0.9),
    'jamesbond': (29.0, 302.8),
    'kangaroo': (52.0, 3035.0),
    'krull': (1598.0, 2665.5),
    'kung_fu_master': (258.5, 22736.3),
    'montezuma_revenge': (0.0, 4753.3),
    'ms_pacman': (307.3, 6951.6),
    'name_this_game': (2292.3, 8049.0),
    'phoenix': (761.4, 7242.6),
    'pitfall': (-229.4, 6463.7),
    'pong': (-20.7, 14.6),
    'private_eye': (24.9, 69571.3),
    'qbert': (163.9, 13455.0),
    'riverraid': (1338.5, 17118.0),
    'road_runner': (11.5, 7845.0),
    'robotank': (2.2, 11.9),
    'seaquest': (68.4, 42054.7),
    'skiing': (-17098.1, -4336.9),
    'solaris': (1236.3, 12326.7),
    'space_invaders': (148.0, 1668.7),
    'star_gunner': (664.0, 10250.0),
    'surround': (-10.0, 6.5),
    'tennis': (-23.8, -8.3),
    'time_pilot': (3568.0, 5229.2),
    'tutankham': (11.4, 167.6),
    'up_n_down': (533.4, 11693.2),
    'venture': (0.0, 1187.5),
    # Note the random agent score on Video Pinball is sometimes greater than the
    # human score under other evaluation methods.
    'video_pinball': (16256.9, 17667.9),
    'wizard_of_wor': (563.5, 4756.5),
    'yars_revenge': (3092.9, 54576.9),
    'zaxxon': (32.5, 9173.3),
}

# Positions of the two scores inside each `_ATARI_DATA` tuple.
_RANDOM_COL = 0
_HUMAN_COL = 1

# Alphabetically sorted names of all games with score data.
ATARI_GAMES = tuple(sorted(_ATARI_DATA))
|
||||
|
||||
|
||||
def get_human_normalized_score(game: str, raw_score: float) -> float:
  """Maps a raw game score onto the human-normalized scale.

  On that scale 0 corresponds to the random score and 1 to the human score
  stored for `game`. Games without stored scores yield NaN.

  Args:
    game: name of the game, as used for the keys of `_ATARI_DATA`.
    raw_score: score to normalize.

  Returns:
    The human-normalized score.
  """
  if game in _ATARI_DATA:
    reference_scores = _ATARI_DATA[game]
  else:
    reference_scores = (math.nan, math.nan)
  random_score = reference_scores[_RANDOM_COL]
  human_score = reference_scores[_HUMAN_COL]
  return (raw_score - random_score) / (human_score - random_score)
|
||||
@@ -0,0 +1,217 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""dm_env environment wrapper around Gym Atari configured to be like Xitari.
|
||||
|
||||
Gym Atari is built on the Arcade Learning Environment (ALE), whereas Xitari is
|
||||
an old fork of the ALE.
|
||||
"""
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import atari_py # pylint: disable=unused-import for gym to load Atari games.
|
||||
import dm_env
|
||||
from dm_env import specs
|
||||
import gym
|
||||
import numpy as np
|
||||
|
||||
from tandem_dqn import atari_data
|
||||
|
||||
# Suffix ending the Gym id of every environment registered by this module.
_GYM_ID_SUFFIX = '-xitari-v1'
# Marker placed before that suffix for the sticky-actions variants.
_SA_SUFFIX = '-sa'


def _game_id(game, sticky_actions):
  """Returns the Gym id of `game`, with or without sticky actions."""
  sticky_marker = _SA_SUFFIX if sticky_actions else ''
  return ''.join((game, sticky_marker, _GYM_ID_SUFFIX))
|
||||
|
||||
|
||||
def _register_atari_environments():
  """Registers Atari environments in Gym to be as similar to Xitari as possible.

  Compared to PongNoFrameSkip-v4 and the like, max_episode_steps is left unset
  and only the games listed in `atari_data.ATARI_GAMES` are registered.

  Every game is registered twice: once without sticky actions, and once as a
  sticky-actions variant whose id carries an '-sa' suffix.
  """
  for sticky_actions in (False, True):
    # Sticky-actions variants use a repeat-action probability of 0.25.
    repeat_action_probability = 0.25 if sticky_actions else 0.0
    for game in atari_data.ATARI_GAMES:
      env_kwargs = {  # Explicitly set all known arguments.
          'game': game,
          'mode': None,  # Not necessarily the same as 0.
          'difficulty': None,  # Not necessarily the same as 0.
          'obs_type': 'image',
          'frameskip': 1,  # Get every frame.
          'repeat_action_probability': repeat_action_probability,
          'full_action_space': False,
      }
      gym.envs.registration.register(
          id=_game_id(game, sticky_actions),
          entry_point='gym.envs.atari:AtariEnv',
          kwargs=env_kwargs,
          max_episode_steps=None,  # No time limit, handled in run loop.
          nondeterministic=False,  # Xitari is deterministic.
      )


# Registration happens once, as a side effect of importing this module.
_register_atari_environments()
|
||||
|
||||
|
||||
class GymAtari(dm_env.Environment):
  """Exposes a Gym Atari environment through the `dm_env.Environment` API.

  Observations are `(frame, lives)` pairs, where `lives` is the ALE's current
  life count as an `np.int32`.
  """

  def __init__(self, game, sticky_actions, seed):
    """Creates and seeds the Gym environment registered for `game`."""
    env_id = _game_id(game, sticky_actions)
    self._gym_env = gym.make(env_id)
    self._gym_env.seed(seed)
    # True whenever the next `step` call has to begin a new episode.
    self._start_of_episode = True

  def reset(self) -> dm_env.TimeStep:
    """Resets the environment and starts a new episode."""
    frame = self._gym_env.reset()
    num_lives = np.int32(self._gym_env.ale.lives())
    first_timestep = dm_env.restart((frame, num_lives))
    self._start_of_episode = False
    return first_timestep

  def step(self, action: np.int32) -> dm_env.TimeStep:
    """Updates the environment given an action and returns a timestep."""
    # At the start of an episode (i.e. after a LAST timestep) the Gym
    # environment is reset instead of stepped. Gym environments do allow
    # stepping through episode boundaries (similar to dm_env), but they emit
    # a warning when doing so.
    if self._start_of_episode:
      observation = self._gym_env.reset()
      step_type = dm_env.StepType.FIRST
      reward = None
      discount = None
      done = False
    else:
      observation, reward, done, info = self._gym_env.step(action)
      if not done:
        step_type = dm_env.StepType.MID
        discount = 1.
      else:
        assert 'TimeLimit.truncated' not in info, 'Should never truncate.'
        step_type = dm_env.StepType.LAST
        discount = 0.

    num_lives = np.int32(self._gym_env.ale.lives())
    timestep = dm_env.TimeStep(
        step_type=step_type,
        observation=(observation, num_lives),
        reward=reward,
        discount=discount,
    )
    # An episode that just ended makes the next call start a new one.
    self._start_of_episode = done
    return timestep

  def observation_spec(self) -> Tuple[specs.Array, specs.Array]:
    """Returns the specs of the `(rgb, lives)` observation pair."""
    frame_space = self._gym_env.observation_space
    rgb_spec = specs.Array(
        shape=frame_space.shape, dtype=frame_space.dtype, name='rgb')
    lives_spec = specs.Array(shape=(), dtype=np.int32, name='lives')
    return (rgb_spec, lives_spec)

  def action_spec(self) -> specs.DiscreteArray:
    """Returns the spec of the discrete action."""
    num_actions = self._gym_env.action_space.n
    return specs.DiscreteArray(
        num_values=num_actions, dtype=np.int32, name='action')

  def close(self):
    """Closes the underlying Gym environment."""
    self._gym_env.close()
|
||||
|
||||
|
||||
class RandomNoopsEnvironmentWrapper(dm_env.Environment):
|
||||
"""Adds a random number of noop actions at the beginning of each episode."""
|
||||
|
||||
def __init__(self,
|
||||
environment: dm_env.Environment,
|
||||
max_noop_steps: int,
|
||||
min_noop_steps: int = 0,
|
||||
noop_action: int = 0,
|
||||
seed: Optional[int] = None):
|
||||
"""Initializes the random noops environment wrapper."""
|
||||
self._environment = environment
|
||||
if max_noop_steps < min_noop_steps:
|
||||
raise ValueError('max_noop_steps must be greater or equal min_noop_steps')
|
||||
self._min_noop_steps = min_noop_steps
|
||||
self._max_noop_steps = max_noop_steps
|
||||
self._noop_action = noop_action
|
||||
self._rng = np.random.RandomState(seed)
|
||||
|
||||
def reset(self):
|
||||
"""Begins new episode.
|
||||
|
||||
This method resets the wrapped environment and applies a random number
|
||||
of noop actions before returning the last resulting observation
|
||||
as the first episode timestep. Intermediate timesteps emitted by the inner
|
||||
environment (including all rewards and discounts) are discarded.
|
||||
|
||||
Returns:
|
||||
First episode timestep corresponding to the timestep after a random number
|
||||
of noop actions are applied to the inner environment.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if an episode end occurs while the inner environment
|
||||
is being stepped through with the noop action.
|
||||
"""
|
||||
return self._apply_random_noops(initial_timestep=self._environment.reset())
|
||||
|
||||
def step(self, action):
|
||||
"""Steps environment given action.
|
||||
|
||||
If beginning a new episode then random noops are applied as in `reset()`.
|
||||
|
||||
Args:
|
||||
action: action to pass to environment conforming to action spec.
|
||||
|
||||
Returns:
|
||||
`Timestep` from the inner environment unless beginning a new episode, in
|
||||
which case this is the timestep after a random number of noop actions
|
||||
are applied to the inner environment.
|
||||
"""
|
||||
timestep = self._environment.step(action)
|
||||
if timestep.first():
|
||||
return self._apply_random_noops(initial_timestep=timestep)
|
||||
else:
|
||||
return timestep
|
||||
|
||||
def _apply_random_noops(self, initial_timestep):
  """Steps the inner environment with a random number of noop actions.

  The number of noops is drawn from the wrapper's own random state, between
  `min_noop_steps` and `max_noop_steps` inclusive (the upper bound passed to
  `randint` is `max_noop_steps + 1`).

  Args:
    initial_timestep: first timestep of the episode as emitted by the inner
      environment.

  Returns:
    A fresh FIRST timestep built from the observation of the last step taken
    (or of `initial_timestep` when zero noops were drawn).

  Raises:
    RuntimeError: if the episode ends while noops are being applied.
  """
  assert initial_timestep.first()
  lower, upper = self._min_noop_steps, self._max_noop_steps + 1
  num_steps = self._rng.randint(lower, upper)

  latest = initial_timestep
  remaining = num_steps
  while remaining > 0:
    latest = self._environment.step(self._noop_action)
    if latest.last():
      raise RuntimeError(
          'Episode ended while applying %s noop actions.' % num_steps)
    remaining -= 1

  # Re-wrap the observation as a FIRST timestep so that rewards and discounts
  # gathered during the noops are discarded.
  return dm_env.restart(latest.observation)
|
||||
|
||||
## All methods except for reset and step redirect to the underlying env.
|
||||
|
||||
def observation_spec(self):
  """Returns the observation spec of the wrapped environment, unchanged."""
  inner = self._environment
  return inner.observation_spec()
|
||||
|
||||
def action_spec(self):
  """Returns the action spec of the wrapped environment, unchanged."""
  inner = self._environment
  return inner.action_spec()
|
||||
|
||||
def reward_spec(self):
  """Returns the reward spec of the wrapped environment, unchanged."""
  inner = self._environment
  return inner.reward_spec()
|
||||
|
||||
def discount_spec(self):
  """Returns the discount spec of the wrapped environment, unchanged."""
  inner = self._environment
  return inner.discount_spec()
|
||||
|
||||
def close(self):
  """Closes the wrapped environment and returns whatever it returns."""
  inner = self._environment
  return inner.close()
|
||||
@@ -0,0 +1,190 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Losses for TandemDQN."""
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
import chex
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import rlax
|
||||
|
||||
from tandem_dqn import networks
|
||||
|
||||
# Batch variants of double_q_learning and SARSA: `jax.vmap` with default
# arguments maps every positional argument over its leading axis.
batch_double_q_learning = jax.vmap(rlax.double_q_learning)
batch_sarsa_learning = jax.vmap(rlax.sarsa)

# Batch variant of quantile_q_learning with fixed tau input across batch.
# Per `in_axes`, the second and the last positional arguments are passed
# unbatched (`None`) and are therefore shared by all batch elements; the
# remaining arguments are mapped over axis 0.
batch_quantile_q_learning = jax.vmap(
    rlax.quantile_q_learning, in_axes=(0, None, 0, 0, 0, 0, 0, None))
|
||||
|
||||
|
||||
def _mc_learning(
    q_tm1: chex.Array,
    a_tm1: chex.Numeric,
    mc_return_tm1: chex.Array,
) -> chex.Numeric:
  """Computes the Monte-Carlo error for a single (unbatched) transition.

  Args:
    q_tm1: rank-1 float array of action values at time t-1.
    a_tm1: scalar integer action taken at time t-1.
    mc_return_tm1: Monte-Carlo return observed from time t-1.

  Returns:
    `mc_return_tm1` minus the value `q_tm1` assigns to action `a_tm1`.
  """
  chex.assert_rank([q_tm1, a_tm1], [1, 0])
  chex.assert_type([q_tm1, a_tm1], [float, int])
  q_of_taken_action = q_tm1[a_tm1]
  return mc_return_tm1 - q_of_taken_action


# Batch variant of MC learning (all arguments mapped over their leading axis).
batch_mc_learning = jax.vmap(_mc_learning)
|
||||
|
||||
|
||||
def _qr_loss(q_tm1, q_t, q_target_t, transitions, rng_key):
  """Calculates QR-Learning loss from network outputs and transitions.

  Args:
    q_tm1: online network outputs at time t-1; its `q_dist` is trained.
    q_t: unused.
    q_target_t: target network outputs at time t; its `q_dist` is used both to
      select the bootstrap action and to evaluate it.
    transitions: batch providing `a_tm1`, `r_t` and `discount_t`.
    rng_key: unused.

  Returns:
    Mean of the per-example quantile regression losses.
  """
  del q_t, rng_key  # Unused.
  huber_param = 1.
  taus = networks.make_quantiles()
  target_dist_t = q_target_t.q_dist
  per_example_losses = batch_quantile_q_learning(
      q_tm1.q_dist,
      taus,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      target_dist_t,  # Action selector: no double Q-learning here.
      target_dist_t,
      huber_param,
  )
  return jnp.mean(per_example_losses)
|
||||
|
||||
|
||||
def _sarsa_loss(q_tm1, q_t, transitions, rng_key):
  """Calculates SARSA loss from network outputs and transitions.

  Args:
    q_tm1: network outputs at time t-1; its `q_values` are trained.
    q_t: network outputs at time t, evaluated at the recorded action `a_t`.
    transitions: batch providing `a_tm1`, `r_t`, `discount_t` and `a_t`.
    rng_key: unused.

  Returns:
    Mean L2 loss of the gradient-clipped SARSA TD errors.
  """
  del rng_key  # Unused.
  clip_bound = 1. / 32
  raw_td_errors = batch_sarsa_learning(
      q_tm1.q_values,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_t.q_values,
      transitions.a_t,
  )
  clipped_td_errors = rlax.clip_gradient(
      raw_td_errors, -clip_bound, clip_bound)
  return jnp.mean(rlax.l2_loss(clipped_td_errors))
|
||||
|
||||
|
||||
def _mc_loss(q_tm1, transitions, rng_key):
  """Calculates Monte-Carlo return loss, i.e. regression towards MC return.

  Args:
    q_tm1: network outputs at time t-1; its `q_values` are trained.
    transitions: batch providing `a_tm1` and `mc_return_tm1`.
    rng_key: unused.

  Returns:
    Mean L2 loss of the per-example Monte-Carlo errors.
  """
  del rng_key  # Unused.
  mc_errors = batch_mc_learning(
      q_tm1.q_values,
      transitions.a_tm1,
      transitions.mc_return_tm1,
  )
  per_example_losses = rlax.l2_loss(mc_errors)
  return jnp.mean(per_example_losses)
|
||||
|
||||
|
||||
def _double_q_loss(q_tm1, q_t, q_target_t, transitions, rng_key):
  """Calculates Double Q-Learning loss from network outputs and transitions.

  Args:
    q_tm1: network outputs at time t-1; its `q_values` are trained.
    q_t: network outputs at time t, passed to `rlax.double_q_learning` as the
      action selector.
    q_target_t: network outputs at time t, passed as the bootstrap values.
    transitions: batch providing `a_tm1`, `r_t` and `discount_t`.
    rng_key: unused.

  Returns:
    Mean L2 loss of the gradient-clipped TD errors.
  """
  del rng_key  # Unused.
  clip_bound = 1. / 32
  raw_td_errors = batch_double_q_learning(
      q_tm1.q_values,
      transitions.a_tm1,
      transitions.r_t,
      transitions.discount_t,
      q_target_t.q_values,
      q_t.q_values,
  )
  clipped_td_errors = rlax.clip_gradient(
      raw_td_errors, -clip_bound, clip_bound)
  return jnp.mean(rlax.l2_loss(clipped_td_errors))
|
||||
|
||||
|
||||
def _q_regression_loss(q_tm1, q_tm1_target):
  """Loss for regression of all action values towards targets.

  Args:
    q_tm1: network outputs whose `q_values` are trained.
    q_tm1_target: network outputs whose `q_values` serve as regression
      targets; gradients are stopped on them.

  Returns:
    Mean L2 loss over all entries of the Q-value difference.
  """
  fixed_targets = jax.lax.stop_gradient(q_tm1_target.q_values)
  residuals = q_tm1.q_values - fixed_targets
  return jnp.mean(rlax.l2_loss(residuals))
|
||||
|
||||
|
||||
def make_loss_fn(loss_type: str, active: bool) -> Callable[..., Any]:
  """Create active or passive loss function of given type.

  Every returned function has the signature
  `(q_tm1, q_t, q_target_t, transitions, rng_key)`, where the three network
  output arguments are containers with `.active` and `.passive` members.

  Args:
    loss_type: one of 'double_q', 'sarsa', 'mc_return', 'double_q_v',
      'double_q_p', 'double_q_pv', 'q_regression' or 'qr'.
    active: if True the loss trains the active network and treats the passive
      one as "the other" network; if False the roles are swapped.

  Returns:
    The loss function selected by `loss_type`.

  Raises:
    ValueError: if `loss_type` is not one of the names listed above.
  """

  def _active_of(tandem):
    return tandem.active

  def _passive_of(tandem):
    return tandem.passive

  # `primary` picks the network being trained, `secondary` the other one.
  if active:
    primary, secondary = _active_of, _passive_of
  else:
    primary, secondary = _passive_of, _active_of

  def sarsa_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """SARSA loss using own networks."""
    del q_t  # Unused.
    own_q_tm1 = primary(q_tm1)
    own_q_target_t = primary(q_target_t)
    return _sarsa_loss(own_q_tm1, own_q_target_t, transitions, rng_key)

  def mc_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """MonteCarlo loss."""
    del q_t, q_target_t
    own_q_tm1 = primary(q_tm1)
    return _mc_loss(own_q_tm1, transitions, rng_key)

  def double_q_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """Regular DoubleQ loss using own networks."""
    own_q_tm1 = primary(q_tm1)
    selector_q_t = primary(q_t)
    value_q_target_t = primary(q_target_t)
    return _double_q_loss(own_q_tm1, selector_q_t, value_q_target_t,
                          transitions, rng_key)

  def double_q_loss_v_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """DoubleQ loss using other network's (target) value function."""
    own_q_tm1 = primary(q_tm1)
    selector_q_t = primary(q_t)
    value_q_target_t = secondary(q_target_t)
    return _double_q_loss(own_q_tm1, selector_q_t, value_q_target_t,
                          transitions, rng_key)

  def double_q_loss_p_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """DoubleQ loss using other network's (online) argmax policy."""
    own_q_tm1 = primary(q_tm1)
    selector_q_t = secondary(q_t)
    value_q_target_t = primary(q_target_t)
    return _double_q_loss(own_q_tm1, selector_q_t, value_q_target_t,
                          transitions, rng_key)

  def double_q_loss_pv_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """DoubleQ loss using other network's argmax policy & target value fn."""
    own_q_tm1 = primary(q_tm1)
    selector_q_t = secondary(q_t)
    value_q_target_t = secondary(q_target_t)
    return _double_q_loss(own_q_tm1, selector_q_t, value_q_target_t,
                          transitions, rng_key)

  def q_regression_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """Pure regression of q_tm1(self) towards q_tm1(other)."""
    del q_t, q_target_t, transitions, rng_key  # Unused.
    return _q_regression_loss(primary(q_tm1), secondary(q_tm1))

  def qr_loss_fn(q_tm1, q_t, q_target_t, transitions, rng_key):
    """QR-Q loss using own networks."""
    own_q_tm1 = primary(q_tm1)
    own_q_t = primary(q_t)
    own_q_target_t = primary(q_target_t)
    return _qr_loss(own_q_tm1, own_q_t, own_q_target_t, transitions, rng_key)

  # Dispatch table from loss name to the closure implementing it.
  loss_fns = {
      'double_q': double_q_loss_fn,
      'sarsa': sarsa_loss_fn,
      'mc_return': mc_loss_fn,
      'double_q_v': double_q_loss_v_fn,
      'double_q_p': double_q_loss_p_fn,
      'double_q_pv': double_q_loss_pv_fn,
      'q_regression': q_regression_loss_fn,
      'qr': qr_loss_fn,
  }
  if loss_type not in loss_fns:
    raise ValueError('Unknown loss "{}"'.format(loss_type))
  return loss_fns[loss_type]
|
||||
@@ -0,0 +1,95 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for Tandem losses."""
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import chex
|
||||
import jax
|
||||
from jax.config import config
|
||||
import numpy as np
|
||||
|
||||
from tandem_dqn import agent
|
||||
from tandem_dqn import losses
|
||||
from tandem_dqn import networks
|
||||
from tandem_dqn import replay
|
||||
|
||||
|
||||
def make_tandem_qvals():
  """Builds constant tandem Q-values: 3.0 for active, 2.0 for passive.

  Both networks get a float32 array of shape (3, 5).
  """
  shape = (3, 5)
  active_q = np.full(shape, 3., np.float32)
  passive_q = np.full(shape, 2., np.float32)
  return agent.TandemTuple(
      active=networks.QNetworkOutputs(active_q),
      passive=networks.QNetworkOutputs(passive_q),
  )
|
||||
|
||||
|
||||
def make_transition():
  """Builds a constant batch of three transitions for the loss tests.

  Every transition takes action 1, receives reward 5.0 and discount 0.9; the
  states are all zeros.
  """
  batch_size = 3
  return replay.Transition(
      s_tm1=np.zeros(batch_size),
      a_tm1=np.ones(batch_size, np.int32),
      r_t=np.full(batch_size, 5.),
      discount_t=np.full(batch_size, 0.9),
      s_t=np.zeros(batch_size),
  )
|
||||
|
||||
|
||||
class DoubleQLossesTest(parameterized.TestCase):
  """Checks which network outputs receive gradients from the tandem losses."""

  def setUp(self):
    super().setUp()
    self.qs = make_tandem_qvals()
    self.transition = make_transition()
    self.rng_key = jax.random.PRNGKey(42)

  def _loss_gradients(self, loss_type, active):
    """Returns gradients of the loss wrt (q_tm1, q_t, q_t_target)."""
    loss_fn = losses.make_loss_fn(loss_type, active=active)
    grad_fn = self.variant(jax.grad(loss_fn, argnums=(0, 1, 2)))
    # The same Q-values are used for all three network output arguments.
    return grad_fn(
        self.qs, self.qs, self.qs, self.transition, self.rng_key)

  @chex.all_variants()
  @parameterized.parameters(
      ('double_q',), ('double_q_v',), ('double_q_p',), ('double_q_pv',),
      ('q_regression',),
  )
  def test_active_loss_gradients(self, loss_type):
    dldq_tm1, dldq_t, dldq_t_target = self._loss_gradients(
        loss_type, active=True)

    # Assert that only active net gets nonzero gradients.
    self.assertGreater(np.sum(np.abs(dldq_tm1.active.q_values)), 0.)
    expected_zero = (
        dldq_t.active,
        dldq_t_target.active,
        dldq_t.passive,
        dldq_tm1.passive,
        dldq_t_target.passive,
    )
    for grads in expected_zero:
      self.assertTrue(np.all(grads.q_values == 0.))

  @chex.all_variants()
  @parameterized.parameters(
      ('double_q',), ('double_q_v',), ('double_q_p',), ('double_q_pv',),
      ('q_regression',),
  )
  def test_passive_loss_gradients(self, loss_type):
    dldq_tm1, dldq_t, dldq_t_target = self._loss_gradients(
        loss_type, active=False)

    # Assert that only passive net gets nonzero gradients.
    self.assertGreater(np.sum(np.abs(dldq_tm1.passive.q_values)), 0.)
    expected_zero = (
        dldq_t.passive,
        dldq_t_target.passive,
        dldq_t.active,
        dldq_tm1.active,
        dldq_t_target.active,
    )
    for grads in expected_zero:
      self.assertTrue(np.all(grads.q_values == 0.))
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # Set JAX's rank promotion option to 'raise' before the tests run.
  config.update('jax_numpy_rank_promotion', 'raise')
  absltest.main()
|
||||
@@ -0,0 +1,216 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""DQN agent network components and implementation."""
|
||||
|
||||
import typing
|
||||
from typing import Any, Callable, Tuple, Union
|
||||
|
||||
import chex
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
# Aliases for the Haiku types used by this module.
Network = hk.Transformed
Params = hk.Params
# Type of the plain network functions built by the factories in this module.
NetworkFn = Callable[..., Any]
|
||||
|
||||
|
||||
class QNetworkOutputs(typing.NamedTuple):
  """Output container of a plain Q-network."""

  # Action values produced by the network.
  q_values: jnp.ndarray
|
||||
|
||||
|
||||
class QRNetworkOutputs(typing.NamedTuple):
  """Output container of a quantile regression Q-network."""

  # Action values produced by the network.
  q_values: jnp.ndarray
  # Quantile values of the return distribution.
  q_dist: jnp.ndarray
|
||||
|
||||
|
||||
# Number of quantiles produced by `make_quantiles()` for QR-DQN.
NUM_QUANTILES = 201
|
||||
|
||||
|
||||
def _dqn_default_initializer(
    num_input_units: int) -> hk.initializers.Initializer:
  """Legacy DQN initialization scheme for weights and biases.

  Each parameter is drawn independently and uniformly from [`-c`, `c`] with
  `c = 1 / np.sqrt(num_input_units)`. Here `num_input_units` is the number of
  input units feeding a single output unit of the layer: the total number of
  inputs for linear (dense) layers, and
  `num_input_channels * kernel_width * kernel_height` for convolutional
  layers. Historically this scheme was used for all convolutional and linear
  layers of DQN-type agents' networks.

  Args:
    num_input_units: number of input units to a single output unit of the layer.

  Returns:
    Haiku weight initializer.
  """
  bound = np.sqrt(1 / num_input_units)
  return hk.initializers.RandomUniform(minval=-bound, maxval=bound)
|
||||
|
||||
|
||||
def make_quantiles():
  """Quantiles for QR-DQN.

  Returns:
    Array of `NUM_QUANTILES` values `(i + 0.5) / NUM_QUANTILES` for
    `i = 0, ..., NUM_QUANTILES - 1`.
  """
  shifted_indices = jnp.arange(0, NUM_QUANTILES) + 0.5
  return shifted_indices / float(NUM_QUANTILES)
|
||||
|
||||
|
||||
def conv(
    num_features: int,
    kernel_shape: Union[int, Tuple[int, int]],
    stride: Union[int, Tuple[int, int]],
    name=None,
) -> NetworkFn:
  """Convolutional layer with DQN's legacy weight initialization scheme.

  Args:
    num_features: number of output channels.
    kernel_shape: kernel size, either a `(height, width)` pair or a single
      integer denoting a square kernel.
    stride: stride forwarded unchanged to `hk.Conv2D`.
    name: optional name of the Haiku module.

  Returns:
    Function that applies a 'VALID'-padded `hk.Conv2D` to its input, with
    weights and biases initialized by `_dqn_default_initializer`.
  """

  def net_fn(inputs):
    """Function representing conv layer with DQN's legacy initialization."""
    # The annotation allows a scalar kernel size, which cannot be indexed;
    # treat it as a square kernel when computing the receptive field size.
    if np.ndim(kernel_shape) == 0:
      kernel_area = kernel_shape * kernel_shape
    else:
      kernel_area = kernel_shape[0] * kernel_shape[1]
    # Number of input units feeding one output unit: channels * kernel area.
    num_input_units = inputs.shape[-1] * kernel_area
    initializer = _dqn_default_initializer(num_input_units)
    layer = hk.Conv2D(
        num_features,
        kernel_shape=kernel_shape,
        stride=stride,
        w_init=initializer,
        b_init=initializer,
        padding='VALID',
        name=name)
    return layer(inputs)

  return net_fn
|
||||
|
||||
|
||||
def linear(num_outputs: int, with_bias=True, name=None) -> NetworkFn:
  """Linear layer with DQN's legacy weight initialization scheme.

  Args:
    num_outputs: number of output units.
    with_bias: forwarded to `hk.Linear`.
    name: optional name of the Haiku module.

  Returns:
    Function that applies the `hk.Linear` layer to its input.
  """

  def net_fn(inputs):
    """Function representing linear layer with DQN's legacy initialization."""
    fan_in = inputs.shape[-1]
    init = _dqn_default_initializer(fan_in)
    dense = hk.Linear(
        num_outputs,
        with_bias=with_bias,
        w_init=init,
        b_init=init,
        name=name,
    )
    return dense(inputs)

  return net_fn
|
||||
|
||||
|
||||
def linear_with_shared_bias(num_outputs: int, name=None) -> NetworkFn:
  """Linear layer with single shared bias instead of one bias per output.

  Args:
    num_outputs: number of output units.
    name: optional name of the bias-free `hk.Linear` module.

  Returns:
    Function that applies the layer to its input.
  """

  def layer_fn(inputs):
    """Function representing a linear layer with single shared bias."""
    init = _dqn_default_initializer(inputs.shape[-1])
    projection = hk.Linear(
        num_outputs, with_bias=False, w_init=init, name=name)(inputs)
    # A single bias parameter of shape [1], broadcast over all outputs.
    shared_bias = hk.get_parameter('b', [1], inputs.dtype, init=init)
    return projection + jnp.broadcast_to(shared_bias, projection.shape)

  return layer_fn
|
||||
|
||||
|
||||
def dqn_torso() -> NetworkFn:
  """DQN convolutional torso.

  Includes scaling from [`0`, `255`] (`uint8`) to [`0`, `1`] (`float32`).

  Returns:
    Network function that `haiku.transform` can be called on.
  """

  def _to_unit_floats(frames):
    return frames.astype(jnp.float32) / 255.

  def net_fn(inputs):
    """Function representing convolutional torso for a DQN Q-network."""
    layers = [
        _to_unit_floats,
        conv(32, kernel_shape=(8, 8), stride=(4, 4), name='conv1'),
        jax.nn.relu,
        conv(64, kernel_shape=(4, 4), stride=(2, 2), name='conv2'),
        jax.nn.relu,
        conv(64, kernel_shape=(3, 3), stride=(1, 1), name='conv3'),
        jax.nn.relu,
        hk.Flatten(),
    ]
    return hk.Sequential(layers)(inputs)

  return net_fn
|
||||
|
||||
|
||||
def dqn_value_head(num_actions: int, shared_bias: bool = False) -> NetworkFn:
  """Regular DQN Q-value head with single hidden layer.

  Args:
    num_actions: number of units of the output layer.
    shared_bias: if True the output layer uses one bias shared by all outputs
      (`linear_with_shared_bias`), otherwise a regular `linear` layer.

  Returns:
    Function applying a 512-unit ReLU hidden layer followed by the output
    layer.
  """
  if shared_bias:
    make_output_layer = linear_with_shared_bias
  else:
    make_output_layer = linear

  def net_fn(inputs):
    """Function representing value head for a DQN Q-network."""
    layers = [
        linear(512, name='linear1'),
        jax.nn.relu,
        make_output_layer(num_actions, name='output'),
    ]
    return hk.Sequential(layers)(inputs)

  return net_fn
|
||||
|
||||
|
||||
def qr_atari_network(num_actions: int, quantiles: jnp.ndarray) -> NetworkFn:
  """QR-DQN network, expects `uint8` input.

  Args:
    num_actions: number of actions.
    quantiles: rank-1 array of quantiles; only its length is used here.

  Returns:
    Function mapping inputs to `QRNetworkOutputs`, where `q_dist` has shape
    `(-1, num_quantiles, num_actions)` and `q_values` is its mean over the
    quantile axis with gradients stopped.
  """
  chex.assert_rank(quantiles, 1)
  num_quantiles = len(quantiles)

  def net_fn(inputs):
    """Function representing QR-DQN Q-network."""
    torso = dqn_torso()
    head = dqn_value_head(num_quantiles * num_actions)
    flat_output = hk.Sequential([torso, head])(inputs)
    q_dist = jnp.reshape(flat_output, (-1, num_quantiles, num_actions))
    # Q-values are the quantile means; no gradient flows through them.
    q_values = jax.lax.stop_gradient(jnp.mean(q_dist, axis=1))
    return QRNetworkOutputs(q_dist=q_dist, q_values=q_values)

  return net_fn
|
||||
|
||||
|
||||
def double_dqn_atari_network(num_actions: int) -> NetworkFn:
  """DQN network with shared bias in final layer, expects `uint8` input.

  Args:
    num_actions: number of actions, i.e. number of Q-values per input.

  Returns:
    Function mapping inputs to `QNetworkOutputs`.
  """

  def net_fn(inputs):
    """Function representing DQN Q-network with shared bias output layer."""
    torso = dqn_torso()
    head = dqn_value_head(num_actions, shared_bias=True)
    q_values = hk.Sequential([torso, head])(inputs)
    return QNetworkOutputs(q_values=q_values)

  return net_fn
|
||||
|
||||
|
||||
def make_network(network_type: str, num_actions: int) -> Network:
  """Constructs network.

  Args:
    network_type: 'double_q' for the Double DQN network or 'qr' for the
      QR-DQN network.
    num_actions: number of actions of the environment.

  Returns:
    The Haiku-transformed network.

  Raises:
    ValueError: if `network_type` is neither 'double_q' nor 'qr'.
  """
  if network_type == 'double_q':
    return hk.transform(double_dqn_atari_network(num_actions))
  if network_type == 'qr':
    return hk.transform(qr_atari_network(num_actions, make_quantiles()))
  raise ValueError('Unknown network "{}"'.format(network_type))
|
||||
@@ -0,0 +1,487 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Components for DQN."""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import csv
|
||||
import os
|
||||
import timeit
|
||||
from typing import Any, Iterable, Mapping, Optional, Text, Tuple, Union
|
||||
|
||||
import dm_env
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import rlax
|
||||
|
||||
from tandem_dqn import networks
|
||||
from tandem_dqn import processors
|
||||
|
||||
# Type aliases used in the signatures of this module.
Action = int
Network = networks.Network
NetworkParams = networks.Params
PRNGKey = jnp.ndarray  # A size 2 array.
|
||||
|
||||
|
||||
class Agent(abc.ABC):
  """Abstract interface that every agent has to implement."""

  @abc.abstractmethod
  def step(self, timestep: dm_env.TimeStep) -> Action:
    """Picks an action for the given timestep; may also perform learning."""

  @abc.abstractmethod
  def reset(self) -> None:
    """Clears the agent's per-episode state.

    Examples of such state are the frame stack and the action repeat counter.
    Call this at the beginning of every episode.
    """

  @abc.abstractmethod
  def get_state(self) -> Mapping[Text, Any]:
    """Returns the agent state as a dictionary (e.g. for serialization)."""

  @abc.abstractmethod
  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Restores agent state from a (potentially de-serialized) dictionary."""

  @property
  @abc.abstractmethod
  def statistics(self) -> Mapping[Text, float]:
    """Current agent statistics, keyed by name."""
|
||||
|
||||
|
||||
def run_loop(
    agent: Agent,
    environment: dm_env.Environment,
    max_steps_per_episode: int = 0,
    yield_before_reset: bool = False,
) -> Iterable[Tuple[dm_env.Environment, Optional[dm_env.TimeStep], Agent,
                    Optional[Action]]]:
  """Endlessly alternates agent and environment steps, episode after episode.

  At time `t`, `t + 1` environment timesteps and `t + 1` agent steps have been
  seen in the current episode. `t` resets to `0` for the next episode.

  Args:
    agent: Agent to be run, has methods `step(timestep)` and `reset()`.
    environment: Environment to run, has methods `step(action)` and `reset()`.
    max_steps_per_episode: If positive, when time t reaches this value within an
      episode, the episode is truncated.
    yield_before_reset: Whether to additionally yield `(environment, None,
      agent, None)` before the agent and environment is reset at the start of
      each episode.

  Yields:
    Tuple `(environment, timestep_t, agent, a_t)` where
    `a_t = agent.step(timestep_t)`. For the last timestep of an episode the
    agent is still stepped but `None` is yielded in place of its action.
  """
  truncate_episodes = max_steps_per_episode > 0

  while True:  # One iteration per episode.
    if yield_before_reset:
      yield (environment, None, agent, None)

    agent.reset()
    timestep_t = environment.reset()  # timestep_0.
    t = 0
    episode_done = False

    while not episode_done:  # One iteration per non-terminal timestep.
      a_t = agent.step(timestep_t)
      yield (environment, timestep_t, agent, a_t)

      # Advance time by one agent step and one environment step.
      t += 1
      timestep_t = environment.step(a_t)

      if truncate_episodes and t >= max_steps_per_episode:
        assert t == max_steps_per_episode
        # Truncation: relabel the current timestep as the episode's last.
        timestep_t = timestep_t._replace(step_type=dm_env.StepType.LAST)

      episode_done = timestep_t.last()

    # The agent also sees the final timestep; the action it picks is ignored.
    agent.step(timestep_t)
    yield (environment, timestep_t, agent, None)
|
||||
|
||||
|
||||
def generate_statistics(
    trackers: Iterable[Any],
    timestep_action_sequence: Iterable[Tuple[dm_env.Environment,
                                             Optional[dm_env.TimeStep], Agent,
                                             Optional[Action]]]
) -> Mapping[Text, Any]:
  """Generates statistics from a sequence of timestep and actions.

  Args:
    trackers: objects with `reset()`, `step(environment, timestep, agent,
      action)` and `get()` methods. Any iterable is accepted, including
      one-shot iterators such as generators.
    timestep_action_sequence: iterable of `(environment, timestep_t, agent,
      a_t)` tuples; it is consumed entirely.

  Returns:
    A single dictionary merging the dictionaries returned by each tracker's
    `get()`. If several trackers report the same key, the value of the tracker
    that comes first in `trackers` wins.
  """
  # `trackers` is traversed several times below, so materialize it once;
  # otherwise a one-shot iterator would be exhausted by the reset pass.
  trackers = tuple(trackers)

  # Only reset at the start, not between episodes.
  for tracker in trackers:
    tracker.reset()

  for environment, timestep_t, agent, a_t in timestep_action_sequence:
    for tracker in trackers:
      tracker.step(environment, timestep_t, agent, a_t)

  # Merge all statistics dictionaries into one.
  statistics_dicts = (tracker.get() for tracker in trackers)
  return dict(collections.ChainMap(*statistics_dicts))
|
||||
|
||||
|
||||
class EpisodeTracker:
  """Tracks episode return and other statistics."""

  def __init__(self):
    # All counters stay `None` until `reset()` is called.
    self._num_steps_since_reset = None
    self._num_steps_over_episodes = None
    self._episode_returns = None
    self._current_episode_rewards = None
    self._current_episode_step = None

  def step(
      self,
      environment: Optional[dm_env.Environment],
      timestep_t: dm_env.TimeStep,
      agent: Optional[Agent],
      a_t: Optional[Action],
  ) -> None:
    """Accumulates statistics from timestep.

    Args:
      environment: unused.
      timestep_t: timestep whose step type and reward are recorded.
      agent: unused.
      a_t: unused.

    Raises:
      ValueError: if a first timestep arrives while rewards or steps of an
        unfinished episode are still recorded.
    """
    del (environment, agent, a_t)

    if not timestep_t.first():
      # First reward is invalid, all other rewards are appended.
      self._current_episode_rewards.append(timestep_t.reward)
    elif self._current_episode_rewards:
      raise ValueError('Current episode reward list should be empty.')
    elif self._current_episode_step != 0:
      raise ValueError('Current episode step should be zero.')

    self._num_steps_since_reset += 1
    self._current_episode_step += 1

    if not timestep_t.last():
      return

    # Episode finished: record its return and clear the per-episode state.
    finished_return = sum(self._current_episode_rewards)
    self._episode_returns.append(finished_return)
    self._current_episode_rewards = []
    self._num_steps_over_episodes += self._current_episode_step
    self._current_episode_step = 0

  def reset(self) -> None:
    """Resets all gathered statistics, not to be called between episodes."""
    self._episode_returns = []
    self._current_episode_rewards = []
    self._num_steps_since_reset = 0
    self._num_steps_over_episodes = 0
    self._current_episode_step = 0

  def get(self) -> Mapping[Text, Union[int, float, None]]:
    """Aggregates statistics and returns as a dictionary.

    By convention `episode_return` equals `mean_episode_return` (the mean
    return of complete episodes only) once at least one full episode has been
    seen. Before that it equals `current_episode_return`, which is `NaN` if no
    steps have been taken at all.

    Returns:
      A dictionary of aggregated statistics.
    """
    has_full_episode = bool(self._episode_returns)

    if has_full_episode or self._num_steps_since_reset > 0:
      current_episode_return = sum(self._current_episode_rewards)
    else:
      current_episode_return = np.nan

    if has_full_episode:
      mean_episode_return = np.array(self._episode_returns).mean()
      episode_return = mean_episode_return
    else:
      mean_episode_return = np.nan
      episode_return = current_episode_return

    return {
        'mean_episode_return': mean_episode_return,
        'current_episode_return': current_episode_return,
        'episode_return': episode_return,
        'num_episodes': len(self._episode_returns),
        'num_steps_over_episodes': self._num_steps_over_episodes,
        'current_episode_step': self._current_episode_step,
        'num_steps_since_reset': self._num_steps_since_reset,
    }
|
||||
|
||||
|
||||
class StepRateTracker:
  """Tracks step rate, number of steps taken and duration since last reset."""

  def __init__(self):
    # Both fields stay `None` until `reset()` is called.
    self._num_steps_since_reset = None
    self._start = None

  def step(
      self,
      environment: Optional[dm_env.Environment],
      timestep_t: Optional[dm_env.TimeStep],
      agent: Optional[Agent],
      a_t: Optional[Action],
  ) -> None:
    """Counts one step; all arguments are ignored."""
    del (environment, timestep_t, agent, a_t)
    self._num_steps_since_reset += 1

  def reset(self) -> None:
    """Zeroes the step counter and restarts the timer."""
    self._num_steps_since_reset = 0
    self._start = timeit.default_timer()

  def get(self) -> Mapping[Text, float]:
    """Returns step rate, step count and elapsed time since the last reset.

    The step rate is steps per unit of `timeit.default_timer()` time, or `NaN`
    when no step has been taken yet.
    """
    duration = timeit.default_timer() - self._start
    num_steps = self._num_steps_since_reset
    step_rate = num_steps / duration if num_steps > 0 else np.nan
    return {
        'step_rate': step_rate,
        'num_steps': self._num_steps_since_reset,
        'duration': duration,
    }
|
||||
|
||||
|
||||
class UnbiasedExponentialWeightedAverageAgentTracker:
  """'Unbiased Constant-Step-Size Trick' from the Sutton and Barto RL book.

  Maintains an exponentially weighted average of `agent.statistics`. The
  effective step size is `step_size / trace`, where `trace` itself follows
  `trace <- (1 - step_size) * trace + step_size` starting from zero. On the
  first step this ratio is exactly 1, so the initial statistics do not bias
  the average.
  """

  def __init__(self, step_size: float, initial_agent: Agent):
    """Initializes the tracker.

    Args:
      step_size: constant step size of the underlying exponential average.
      initial_agent: agent whose current statistics are reported by `get()`
        until the first `step()` call.
    """
    self._initial_statistics = dict(initial_agent.statistics)
    self._step_size = step_size
    self.trace = 0.
    self._statistics = dict(self._initial_statistics)

  def step(
      self,
      environment: Optional[dm_env.Environment],
      timestep_t: Optional[dm_env.TimeStep],
      agent: Agent,
      a_t: Optional[Action],
  ) -> None:
    """Accumulates agent statistics.

    Args:
      environment: unused.
      timestep_t: unused.
      agent: agent whose `statistics` are folded into the running average.
      a_t: unused.
    """
    del (environment, timestep_t, a_t)

    self.trace = (1 - self._step_size) * self.trace + self._step_size
    final_step_size = self._step_size / self.trace
    assert 0 <= final_step_size <= 1

    if final_step_size == 1:
      # Since the self._initial_statistics is likely to be NaN and
      # 0 * NaN == NaN just replace self._statistics on the first step.
      self._statistics = dict(agent.statistics)
    else:
      # `jax.tree_multimap` was deprecated and later removed from JAX;
      # `jax.tree_util.tree_map` accepts several trees and is its replacement.
      # The new statistics are copied into a plain dict so that both trees
      # share the container type of `self._statistics`.
      self._statistics = jax.tree_util.tree_map(
          lambda s, x: (1 - final_step_size) * s + final_step_size * x,
          self._statistics, dict(agent.statistics))

  def reset(self) -> None:
    """Resets statistics and internal state."""
    self.trace = 0.
    # get() may be called before step() so ensure statistics are initialized.
    self._statistics = dict(self._initial_statistics)

  def get(self) -> Mapping[Text, float]:
    """Returns current accumulated statistics."""
    return self._statistics
|
||||
|
||||
|
||||
def make_default_trackers(initial_agent: Agent):
  """Returns the default list of statistics trackers.

  Args:
    initial_agent: agent whose statistics seed the averaging tracker.

  Returns:
    A list with an episode tracker, a step rate tracker and an exponentially
    weighted average tracker (step size 1e-3) of the agent statistics.
  """
  episode_tracker = EpisodeTracker()
  step_rate_tracker = StepRateTracker()
  agent_statistics_tracker = UnbiasedExponentialWeightedAverageAgentTracker(
      step_size=1e-3, initial_agent=initial_agent)
  return [episode_tracker, step_rate_tracker, agent_statistics_tracker]
|
||||
|
||||
|
||||
class EpsilonGreedyActor(Agent):
  """Agent that acts with a given set of Q-network parameters and epsilon.

  The network parameters are not owned by the actor; they are assigned to
  `network_params` from outside. The actor can be serialized, ensuring
  determinism of execution (e.g. when checkpointing).
  """

  def __init__(
      self,
      preprocessor: processors.Processor,
      network: Network,
      exploration_epsilon: float,
      rng_key: PRNGKey,
  ):
    """Initializes the actor.

    Args:
      preprocessor: processor applied to every incoming timestep; it may
        return `None`, in which case the previous action is repeated.
      network: Q-network whose `q_values` output drives action selection.
      exploration_epsilon: epsilon of the epsilon-greedy policy.
      rng_key: initial random key of the actor.
    """
    self._preprocessor = preprocessor
    self._rng_key = rng_key
    self._action = None
    self.network_params = None  # Nest of arrays (haiku.Params), set externally.

    def select_action(rng_key, network_params, s_t):
      """Samples action from eps-greedy policy wrt Q-values at given state."""
      rng_key, apply_key, policy_key = jax.random.split(rng_key, 3)
      # Add a batch axis of size one for the network, then strip it again.
      batched_state = s_t[None, ...]
      network_outputs = network.apply(network_params, apply_key, batched_state)
      q_t = network_outputs.q_values[0]
      policy = rlax.epsilon_greedy()
      a_t = policy.sample(policy_key, q_t, exploration_epsilon)
      return rng_key, a_t

    self._select_action = jax.jit(select_action)

  def step(self, timestep: dm_env.TimeStep) -> Action:
    """Selects action given a timestep."""
    processed = self._preprocessor(timestep)

    # A `None` result from the preprocessor means: repeat the last action.
    if processed is not None:
      self._rng_key, a_t = self._select_action(
          self._rng_key, self.network_params, processed.observation)
      self._action = Action(jax.device_get(a_t))

    return self._action

  def reset(self) -> None:
    """Resets the agent's episodic state such as frame stack and action repeat.

    This method should be called at the beginning of every episode.
    """
    processors.reset(self._preprocessor)
    self._action = None

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves agent state as a dictionary (e.g. for serialization)."""
    # State contains network params to make agent easy to run from a checkpoint.
    state = {
        'rng_key': self._rng_key,
        'network_params': self.network_params,
    }
    return state

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets agent state from a (potentially de-serialized) dictionary."""
    self._rng_key = state['rng_key']
    self.network_params = state['network_params']

  @property
  def statistics(self) -> Mapping[Text, float]:
    """The actor reports no statistics: always an empty dictionary."""
    return {}
|
||||
|
||||
|
||||
class LinearSchedule:
  """Linear schedule, used for exploration epsilon in DQN agents.

  The value equals `begin_value` up to and including step `begin_t`, moves
  linearly to `end_value` over the following `decay_steps` steps and stays at
  `end_value` afterwards.
  """

  def __init__(self,
               begin_value,
               end_value,
               begin_t,
               end_t=None,
               decay_steps=None):
    """Initializes the schedule.

    Args:
      begin_value: Value returned before the decay starts.
      end_value: Value returned once the decay has finished.
      begin_t: Step at which the decay starts.
      end_t: Step at which the decay finishes. Mutually exclusive with
        `decay_steps`.
      decay_steps: Length of the decay in steps. Mutually exclusive with
        `end_t`.

    Raises:
      ValueError: If both or neither of `end_t` and `decay_steps` are given.
    """
    if (end_t is None) == (decay_steps is None):
      raise ValueError('Exactly one of end_t, decay_steps must be provided.')
    self._decay_steps = decay_steps if end_t is None else end_t - begin_t
    self._begin_t = begin_t
    self._begin_value = begin_value
    self._end_value = end_value

  def __call__(self, t):
    """Implements a linear transition from a begin to an end value.

    Args:
      t: Current step.

    Returns:
      The scheduled value at step `t`.
    """
    elapsed = max(t - self._begin_t, 0)
    if self._decay_steps == 0:
      # Degenerate zero-length decay: behave as a step function (the limit of
      # the linear ramp as its length shrinks) instead of dividing by zero.
      frac = 1. if elapsed > 0 else 0.
    else:
      frac = min(elapsed, self._decay_steps) / self._decay_steps
    return (1 - frac) * self._begin_value + frac * self._end_value
|
||||
|
||||
|
||||
class NullWriter:
  """No-op stand-in for a logging writer; every call is silently ignored."""

  def write(self, *args, **kwargs) -> None:
    """Accepts any arguments and discards them."""
    del args, kwargs  # Intentionally unused.

  def close(self) -> None:
    """Does nothing; there is nothing to release."""
    return None
|
||||
|
||||
|
||||
class CsvWriter:
  """A logging object writing to a CSV file.

  Each `write()` takes an `OrderedDict`, creating one column in the CSV file
  for each dictionary key on the first call. Successive calls to `write()` must
  contain the same dictionary keys.
  """

  def __init__(self, fname: Text):
    """Initializes a `CsvWriter`.

    Args:
      fname: File name (path) for file to be written to. Missing parent
        directories are created.
    """
    dirname = os.path.dirname(fname)
    # `dirname` is empty for a bare file name (file in the current directory);
    # `os.makedirs('')` would raise. `exist_ok` also avoids failing when the
    # directory appears between a check and the creation.
    if dirname:
      os.makedirs(dirname, exist_ok=True)
    self._fname = fname
    self._header_written = False
    self._fieldnames = None

  def write(self, values: collections.OrderedDict) -> None:
    """Appends given values as new row to CSV file.

    Args:
      values: Mapping from column name to value. The keys of the first call
        define the columns.

    Raises:
      ValueError: If `values` contains keys that are not among the columns
        fixed by the first call (raised by `csv.DictWriter`).
    """
    if self._fieldnames is None:
      # Snapshot the keys as a list: a `dict_keys` view would follow later
      # mutations of the caller's dict and cannot be pickled with the state.
      self._fieldnames = list(values.keys())
    # Open a file in 'append' mode, so we can continue logging safely to the
    # same file after e.g. restarting from a checkpoint. `newline=''` is what
    # the `csv` module asks for, so it controls line endings itself.
    with open(self._fname, 'a', newline='') as file:
      # Always use same fieldnames to create writer, this way a consistency
      # check is performed automatically on each write.
      writer = csv.DictWriter(file, fieldnames=self._fieldnames)
      # Write a header if this is the very first write.
      if not self._header_written:
        writer.writeheader()
        self._header_written = True
      writer.writerow(values)

  def close(self) -> None:
    """Closes the `CsvWriter` (no-op: the file is reopened on every write)."""
    pass

  def get_state(self) -> Mapping[Text, Any]:
    """Retrieves `CsvWriter` state as a `dict` (e.g. for serialization)."""
    return {
        'header_written': self._header_written,
        'fieldnames': self._fieldnames
    }

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Sets `CsvWriter` state from a (potentially de-serialized) dictionary."""
    self._header_written = state['header_written']
    self._fieldnames = state['fieldnames']
|
||||
|
||||
|
||||
class NullCheckpoint:
  """Checkpoint stand-in that never saves and never restores.

  Drop-in replacement for a real checkpointing object when checkpointing is
  disabled.
  """

  def __init__(self):
    # Attribute-style container callers attach their objects to.
    self.state = AttributeDict()

  def save(self) -> None:
    """Does nothing."""
    return None

  def can_be_restored(self) -> bool:
    """Always reports that there is nothing to restore from."""
    restorable = False
    return restorable

  def restore(self) -> None:
    """Does nothing."""
    return None
|
||||
|
||||
|
||||
class AttributeDict(dict):
  """A `dict` that supports getting, setting, deleting keys via attributes.

  Accessing or deleting a missing key through attribute syntax raises
  `AttributeError`, as the attribute protocol requires. Leaking `KeyError`
  instead breaks `hasattr()`, `getattr(obj, name, default)` and protocol
  probes such as the `__deepcopy__` lookup done by `copy.deepcopy`.
  Item syntax (`d[key]`) keeps raising `KeyError` as for any `dict`.
  """

  def __getattr__(self, key):
    # Only reached when regular attribute lookup has already failed.
    try:
      return self[key]
    except KeyError as e:
      raise AttributeError(key) from e

  def __setattr__(self, key, value):
    self[key] = value

  def __delattr__(self, key):
    try:
      del self[key]
    except KeyError as e:
      raise AttributeError(key) from e
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,167 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Replay components for DQN-type agents."""
|
||||
|
||||
import collections
|
||||
import typing
|
||||
from typing import Any, Callable, Generic, Iterable, List, Mapping, Optional, Sequence, Text, Tuple, TypeVar
|
||||
|
||||
import dm_env
|
||||
import numpy as np
|
||||
import snappy
|
||||
|
||||
from tandem_dqn import parts
|
||||
|
||||
# A snappy-compressed array: (compressed bytes, original shape, original dtype).
CompressedArray = Tuple[bytes, Tuple, np.dtype]

# Generic replay structure: Any flat named tuple.
ReplayStructure = TypeVar('ReplayStructure', bound=Tuple[Any, ...])
|
||||
|
||||
|
||||
class Transition(typing.NamedTuple):
  """A single environment transition, optionally annotated with its MC return.

  All fields are `Optional`, so an instance with every field set to `None` can
  be constructed.
  """
  # Observation of the previous timestep.
  s_tm1: Optional[np.ndarray]
  # Action taken at the previous timestep.
  a_tm1: Optional[parts.Action]
  # Reward of the current timestep.
  r_t: Optional[float]
  # Discount of the current timestep.
  discount_t: Optional[float]
  # Observation of the current timestep.
  s_t: Optional[np.ndarray]
  # Action taken at the current timestep.
  a_t: Optional[parts.Action] = None
  # Discounted Monte Carlo return from the previous timestep onwards.
  mc_return_tm1: Optional[float] = None
|
||||
|
||||
|
||||
class TransitionReplay(Generic[ReplayStructure]):
  """Uniformly sampled replay backed by a fixed-size circular buffer.

  Items are flat named tuples. Each item passes through `encoder` when stored
  and through `decoder` when read back; both default to the identity.
  """

  def __init__(self,
               capacity: int,
               structure: ReplayStructure,
               random_state: np.random.RandomState,
               encoder: Optional[Callable[[ReplayStructure], Any]] = None,
               decoder: Optional[Callable[[Any], ReplayStructure]] = None):

    def identity(item):
      return item

    self._capacity = capacity
    self._structure = structure
    self._random_state = random_state
    self._encoder = encoder or identity
    self._decoder = decoder or identity

    self._num_added = 0
    self._storage = [None] * capacity

  def add(self, item: ReplayStructure) -> None:
    """Stores one item, overwriting the oldest slot once the buffer is full."""
    encoded = self._encoder(item)
    slot = self._num_added % self._capacity
    self._storage[slot] = encoded
    self._num_added += 1

  def get(self, indices: Sequence[int]) -> List[ReplayStructure]:
    """Returns the decoded items stored at the given buffer indices."""
    items = []
    for index in indices:
      items.append(self._decoder(self._storage[index]))
    return items

  def sample(self, size: int) -> ReplayStructure:
    """Draws `size` items uniformly with replacement, stacked field by field."""
    indices = self._random_state.choice(self.size, size=size, replace=True)
    # Turn a list of tuples into one tuple of per-field batches.
    columns = zip(*self.get(indices))
    stacked = [np.stack(column, axis=0) for column in columns]
    structure_type = type(self._structure)
    return structure_type(*stacked)  # pytype: disable=not-callable

  @property
  def size(self) -> int:
    """Number of items currently held (never more than the capacity)."""
    if self._num_added < self._capacity:
      return self._num_added
    return self._capacity

  @property
  def capacity(self) -> int:
    """Maximum number of items held at any one time."""
    return self._capacity

  def get_state(self) -> Mapping[Text, Any]:
    """Returns the replay state as a dictionary (e.g. for serialization)."""
    return dict(
        storage=self._storage,
        num_added=self._num_added,
    )

  def set_state(self, state: Mapping[Text, Any]) -> None:
    """Restores the replay state from a (possibly de-serialized) dictionary."""
    self._storage = state['storage']
    self._num_added = state['num_added']
|
||||
|
||||
|
||||
class TransitionAccumulatorWithMCReturn:
  """Collects an episode's transitions and emits them with MC returns.

  Transitions are buffered until the episode's LAST timestep; only then are
  they annotated with their discounted Monte Carlo return and yielded, in the
  order in which they occurred.
  """

  def __init__(self):
    self._transitions = collections.deque()
    self.reset()

  def step(self, timestep_t: dm_env.TimeStep,
           a_t: parts.Action) -> Iterable[Transition]:
    """Accumulates timestep and resulting action, maybe yields transitions.

    This is a generator: nothing is recorded until the result is iterated.

    Args:
      timestep_t: The current timestep.
      a_t: The action chosen in response to `timestep_t`.

    Yields:
      On a LAST timestep, every transition of the finished episode with
      `mc_return_tm1` filled in; nothing otherwise.

    Raises:
      ValueError: If no previous timestep is stored and `timestep_t` is not a
        FIRST timestep.
    """
    if timestep_t.first():
      self.reset()

    # No transition can be formed from the very first timestep of an episode.
    if self._timestep_tm1 is None:
      assert self._a_tm1 is None
      if not timestep_t.first():
        raise ValueError('Expected FIRST timestep, got %s.' % str(timestep_t))
      self._timestep_tm1, self._a_tm1 = timestep_t, a_t
      return  # Empty iterable.

    pending = Transition(
        s_tm1=self._timestep_tm1.observation,
        a_tm1=self._a_tm1,
        r_t=timestep_t.reward,
        discount_t=timestep_t.discount,
        s_t=timestep_t.observation,
        a_t=a_t,
        mc_return_tm1=None,
    )
    self._transitions.append(pending)

    self._timestep_tm1, self._a_tm1 = timestep_t, a_t

    if not timestep_t.last():
      # Hold everything back until the episode is over.
      return

    # Walk the episode backwards, folding rewards into the running return,
    # and prepend so the annotated transitions end up in forward order.
    running_return = 0
    annotated = collections.deque()
    while self._transitions:
      latest = self._transitions.pop()
      running_return = latest.discount_t * running_return + latest.r_t
      annotated.appendleft(latest._replace(mc_return_tm1=running_return))
    for annotated_transition in annotated:
      yield annotated_transition

  def reset(self) -> None:
    """Drops buffered transitions; the next timestep must then be FIRST."""
    self._transitions.clear()
    self._timestep_tm1 = None
    self._a_tm1 = None
|
||||
|
||||
|
||||
def compress_array(array: np.ndarray) -> CompressedArray:
  """Snappy-compresses a numpy array.

  Returns:
    A `(compressed bytes, shape, dtype)` triple; shape and dtype are kept so
    that the array can be rebuilt on decompression.
  """
  compressed_bytes = snappy.compress(array)
  return compressed_bytes, array.shape, array.dtype
|
||||
|
||||
|
||||
def uncompress_array(compressed: CompressedArray) -> np.ndarray:
  """Rebuilds a numpy array from its snappy-compressed triple.

  Args:
    compressed: A `(compressed bytes, shape, dtype)` triple.

  Returns:
    The decompressed data viewed with the given dtype and shape.
  """
  payload, shape, dtype = compressed
  raw_bytes = snappy.uncompress(payload)
  flat = np.frombuffer(raw_bytes, dtype=dtype)
  return flat.reshape(shape)
|
||||
@@ -0,0 +1,38 @@
|
||||
# DeepMind libraries.
|
||||
chex==0.0.2
|
||||
dm-env==1.2
|
||||
dm-haiku @ git+git://github.com/deepmind/dm-haiku.git@db991d56563221d5a06be5e7228155e53d01aba9
|
||||
dm-tree==0.1.5
|
||||
optax==0.0.1
|
||||
rlax @ git+git://github.com/deepmind/rlax.git@870cba1ea8ad36725f4f3a790846298657b6fd4b
|
||||
|
||||
# JAX libraries with GPU support.
|
||||
jax==0.1.72
|
||||
jaxlib==0.1.50
|
||||
|
||||
# ALE and Gym libraries.
|
||||
atari-py==0.2.6
|
||||
gym==0.13.1
|
||||
|
||||
# Base dependencies.
|
||||
absl-py==0.9.0
|
||||
numpy==1.18.0
|
||||
Pillow==7.1.2
|
||||
python-snappy==0.6
|
||||
|
||||
# Transitive dependencies.
|
||||
asn1crypto==0.24.0
|
||||
cloudpickle==1.2.2
|
||||
cryptography==2.1.4
|
||||
dataclasses==0.7
|
||||
future==0.18.2
|
||||
idna==2.6
|
||||
keyring==10.6.0
|
||||
keyrings.alt==3.0
|
||||
opencv-python==4.2.0.34
|
||||
opt-einsum==3.2.1
|
||||
pycrypto==2.6.1
|
||||
pyxdg==0.25
|
||||
SecretStorage==2.3.1
|
||||
six==1.15.0
|
||||
toolz==0.11.1
|
||||
@@ -0,0 +1,46 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script installs tandem_dqn in a clean virtualenv and runs an
|
||||
# example training loop. It is designed to be run from the parent directory,
|
||||
# e.g.:
|
||||
#
|
||||
# git clone git@github.com:deepmind/deepmind-research.git
|
||||
# cd deepmind-research
|
||||
# tandem_dqn/run.sh
|
||||
|
||||
# Fail on errors, display commands.
set -e
set -x

TMP_DIR="$(mktemp -d)"

# Remove the temporary virtualenv on every exit path. With `set -e` a failing
# step aborts the script, so cleanup at the end of the file would be skipped.
cleanup() {
  rm -rf -- "${TMP_DIR}"
}
trap cleanup EXIT

# Set up environment.
virtualenv --python=python3.6 "${TMP_DIR}/env"
source "${TMP_DIR}/env/bin/activate"

# Install deps.
pip install --upgrade -r tandem_dqn/requirements.txt

# Run small-scale example for a few thousand steps.
python -m tandem_dqn.run_tandem \
  --environment_name=space_invaders \
  --replay_capacity=10000 --target_network_update_period=40 \
  --num_iterations=2 --num_train_frames=5000 --num_eval_frames=500 \
  --jax_platform_name=cpu

# Clean up happens in the EXIT trap above.
echo "Test run finished."
|
||||
@@ -0,0 +1,356 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""A Tandem DQN agent implemented in JAX, training on Atari."""
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import sys
|
||||
import typing
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import dm_env
|
||||
import jax
|
||||
from jax.config import config
|
||||
import numpy as np
|
||||
import optax
|
||||
|
||||
from tandem_dqn import agent as agent_lib
|
||||
from tandem_dqn import atari_data
|
||||
from tandem_dqn import gym_atari
|
||||
from tandem_dqn import losses
|
||||
from tandem_dqn import networks
|
||||
from tandem_dqn import parts
|
||||
from tandem_dqn import processors
|
||||
from tandem_dqn import replay as replay_lib
|
||||
|
||||
# Relevant flag values are expressed in terms of environment frames.
# All flags are declared with an empty help string.
FLAGS = flags.FLAGS
# Environment and preprocessing.
flags.DEFINE_string('environment_name', 'pong', '')
flags.DEFINE_boolean('use_sticky_actions', False, '')
flags.DEFINE_integer('environment_height', 84, '')
flags.DEFINE_integer('environment_width', 84, '')
# Replay.
flags.DEFINE_integer('replay_capacity', int(1e6), '')
flags.DEFINE_bool('compress_state', True, '')
flags.DEFINE_float('min_replay_capacity_fraction', 0.05, '')
flags.DEFINE_integer('batch_size', 32, '')
flags.DEFINE_integer('max_frames_per_episode', 108000, '')  # 30 mins.
flags.DEFINE_integer('num_action_repeats', 4, '')
flags.DEFINE_integer('num_stacked_frames', 4, '')
# Exploration: epsilon is scheduled linearly for training and fixed for eval.
flags.DEFINE_float('exploration_epsilon_begin_value', 1., '')
flags.DEFINE_float('exploration_epsilon_end_value', 0.01, '')
flags.DEFINE_float('exploration_epsilon_decay_frame_fraction', 0.02, '')
flags.DEFINE_float('eval_exploration_epsilon', 0.01, '')
flags.DEFINE_integer('target_network_update_period', int(1.2e5), '')
flags.DEFINE_float('additional_discount', 0.99, '')
flags.DEFINE_float('max_abs_reward', 1., '')
flags.DEFINE_integer('seed', 1, '')  # GPU may introduce nondeterminism.
flags.DEFINE_integer('num_iterations', 200, '')
flags.DEFINE_integer('num_train_frames', int(1e6), '')  # Per iteration.
flags.DEFINE_integer('num_eval_frames', int(5e5), '')  # Per iteration.
flags.DEFINE_integer('learn_period', 16, '')
# An empty path disables CSV output (a no-op writer is used instead).
flags.DEFINE_string('results_csv_path', '/tmp/results.csv', '')

# Tandem-specific parameters.

# Using fixed configs for optimizers:
# RMSProp: lr = 0.00025, eps=0.01 / (32 ** 2)
# ADAM: lr = 0.00005, eps=0.01 / 32
_OPTIMIZERS = ['rmsprop', 'adam']
flags.DEFINE_enum('optimizer_active', 'rmsprop', _OPTIMIZERS, '')
flags.DEFINE_enum('optimizer_passive', 'rmsprop', _OPTIMIZERS, '')

_NETWORKS = ['double_q', 'qr']
flags.DEFINE_enum('network_active', 'double_q', _NETWORKS, '')
flags.DEFINE_enum('network_passive', 'double_q', _NETWORKS, '')

# A 'qr' loss must be paired with a 'qr' network (checked in `main`).
_LOSSES = ['double_q', 'double_q_v', 'double_q_p', 'double_q_pv', 'qr',
           'q_regression']
flags.DEFINE_enum('loss_active', 'double_q', _LOSSES, '')
flags.DEFINE_enum('loss_passive', 'double_q', _LOSSES, '')

# Number of leading layers tied between the two networks; `main` requires a
# value in [0, 4] and 'double_q' networks for any value above 0.
flags.DEFINE_integer('tied_layers', 0, '')

# Container with an `active` and a `passive` field.
TandemTuple = agent_lib.TandemTuple
|
||||
|
||||
|
||||
def make_optimizer(optimizer_type):
  """Builds the optimizer with the fixed configuration for the given type.

  Args:
    optimizer_type: Either 'rmsprop' or 'adam'.

  Returns:
    The corresponding optax optimizer.

  Raises:
    ValueError: If `optimizer_type` is neither 'rmsprop' nor 'adam'.
  """
  if optimizer_type == 'rmsprop':
    return optax.rmsprop(
        learning_rate=0.00025,
        decay=0.95,
        eps=0.01 / (32**2),
        centered=True)

  if optimizer_type == 'adam':
    return optax.adam(
        learning_rate=0.00005,
        eps=0.01 / 32)

  raise ValueError('Unknown optimizer "{}"'.format(optimizer_type))
|
||||
|
||||
|
||||
def main(argv):
  """Trains Tandem DQN agent on Atari.

  Builds the environment, the active/passive networks, losses and optimizers,
  the replay and the agents from flags, then alternates training of the tandem
  agent with separate evaluations of the active and the passive network,
  writing one row of results per iteration.

  Args:
    argv: Unused command-line arguments.

  Raises:
    ValueError: If a QR network is combined with a non-QR loss (or vice versa),
      or if tied layers are requested for networks other than 'double_q'.
  """
  del argv
  logging.info('Tandem DQN on Atari on %s.',
               jax.lib.xla_bridge.get_backend().platform)
  # All randomness is derived from this single seeded generator, so the order
  # of the draws below matters for reproducibility.
  random_state = np.random.RandomState(FLAGS.seed)
  rng_key = jax.random.PRNGKey(
      random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, dtype=np.int64))

  # An empty results path switches result logging off.
  if FLAGS.results_csv_path:
    writer = parts.CsvWriter(FLAGS.results_csv_path)
  else:
    writer = parts.NullWriter()

  def environment_builder():
    """Creates Atari environment."""
    # Each call draws two fresh seeds from `random_state`.
    env = gym_atari.GymAtari(
        FLAGS.environment_name,
        sticky_actions=FLAGS.use_sticky_actions,
        seed=random_state.randint(1, 2**32))
    return gym_atari.RandomNoopsEnvironmentWrapper(
        env,
        min_noop_steps=1,
        max_noop_steps=30,
        seed=random_state.randint(1, 2**32),
    )

  # This instance is only used for the specs and a sample observation; the
  # loop below builds a new environment for every iteration.
  env = environment_builder()

  logging.info('Environment: %s', FLAGS.environment_name)
  logging.info('Action spec: %s', env.action_spec())
  logging.info('Observation spec: %s', env.observation_spec())
  num_actions = env.action_spec().num_values

  # Check: qr network and qr losses can only be used together.
  if ('qr' in FLAGS.network_active) != ('qr' in FLAGS.loss_active):
    raise ValueError('Active loss/net must either both use QR, or neither.')
  if ('qr' in FLAGS.network_passive) != ('qr' in FLAGS.loss_passive):
    raise ValueError('Passive loss/net must either both use QR, or neither.')
  network = TandemTuple(
      active=networks.make_network(FLAGS.network_active, num_actions),
      passive=networks.make_network(FLAGS.network_passive, num_actions),
  )
  loss = TandemTuple(
      active=losses.make_loss_fn(FLAGS.loss_active, active=True),
      passive=losses.make_loss_fn(FLAGS.loss_passive, active=False),
  )

  # Tied layers.
  assert 0 <= FLAGS.tied_layers <= 4
  if FLAGS.tied_layers > 0 and (FLAGS.network_passive != 'double_q'
                                or FLAGS.network_active != 'double_q'):
    raise ValueError('Tied layers > 0 is only supported for double_q networks.')
  # Layer names in tying order; the first `tied_layers` of them are tied.
  layers = [
      'sequential/sequential/conv1',
      'sequential/sequential/conv2',
      'sequential/sequential/conv3',
      'sequential/sequential_1/linear1'
  ]
  tied_layers = set(layers[:FLAGS.tied_layers])

  def preprocessor_builder():
    """Creates a new Atari preprocessor configured from flags."""
    return processors.atari(
        additional_discount=FLAGS.additional_discount,
        max_abs_reward=FLAGS.max_abs_reward,
        resize_shape=(FLAGS.environment_height, FLAGS.environment_width),
        num_action_repeats=FLAGS.num_action_repeats,
        num_pooled_frames=2,
        zero_discount_on_life_loss=True,
        num_stacked_frames=FLAGS.num_stacked_frames,
        grayscaling=True,
    )

  # Create sample network input from sample preprocessor output.
  sample_processed_timestep = preprocessor_builder()(env.reset())
  sample_processed_timestep = typing.cast(dm_env.TimeStep,
                                          sample_processed_timestep)
  sample_network_input = sample_processed_timestep.observation
  assert sample_network_input.shape == (FLAGS.environment_height,
                                        FLAGS.environment_width,
                                        FLAGS.num_stacked_frames)

  # Epsilon starts decaying once the minimum replay fill level is reached
  # (converted to frames via the action repeat) and decays over a fixed
  # fraction of all training frames.
  exploration_epsilon_schedule = parts.LinearSchedule(
      begin_t=int(FLAGS.min_replay_capacity_fraction * FLAGS.replay_capacity *
                  FLAGS.num_action_repeats),
      decay_steps=int(FLAGS.exploration_epsilon_decay_frame_fraction *
                      FLAGS.num_iterations * FLAGS.num_train_frames),
      begin_value=FLAGS.exploration_epsilon_begin_value,
      end_value=FLAGS.exploration_epsilon_end_value)

  # Optionally keep the two observations of each stored transition compressed.
  if FLAGS.compress_state:

    def encoder(transition):
      """Compresses the observations of a transition for storage."""
      return transition._replace(
          s_tm1=replay_lib.compress_array(transition.s_tm1),
          s_t=replay_lib.compress_array(transition.s_t))

    def decoder(transition):
      """Uncompresses the observations of a stored transition."""
      return transition._replace(
          s_tm1=replay_lib.uncompress_array(transition.s_tm1),
          s_t=replay_lib.uncompress_array(transition.s_t))
  else:
    encoder = None
    decoder = None

  # Template instance; the replay uses its type to build sampled batches.
  replay_structure = replay_lib.Transition(
      s_tm1=None,
      a_tm1=None,
      r_t=None,
      discount_t=None,
      s_t=None,
      a_t=None,
      mc_return_tm1=None,
  )

  replay = replay_lib.TransitionReplay(FLAGS.replay_capacity, replay_structure,
                                       random_state, encoder, decoder)

  optimizer = TandemTuple(
      active=make_optimizer(FLAGS.optimizer_active),
      passive=make_optimizer(FLAGS.optimizer_passive),
  )

  train_rng_key, eval_rng_key = jax.random.split(rng_key)

  train_agent = agent_lib.TandemDqn(
      preprocessor=preprocessor_builder(),
      sample_network_input=sample_network_input,
      network=network,
      optimizer=optimizer,
      loss=loss,
      transition_accumulator=replay_lib.TransitionAccumulatorWithMCReturn(),
      replay=replay,
      batch_size=FLAGS.batch_size,
      exploration_epsilon=exploration_epsilon_schedule,
      min_replay_capacity_fraction=FLAGS.min_replay_capacity_fraction,
      learn_period=FLAGS.learn_period,
      target_network_update_period=FLAGS.target_network_update_period,
      tied_layers=tied_layers,
      rng_key=train_rng_key,
  )
  # NOTE(review): both evaluation actors are created with the same
  # `eval_rng_key` -- confirm that sharing the key is intended.
  eval_agent_active = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network.active,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key)
  eval_agent_passive = parts.EpsilonGreedyActor(
      preprocessor=preprocessor_builder(),
      network=network.passive,
      exploration_epsilon=FLAGS.eval_exploration_epsilon,
      rng_key=eval_rng_key)

  # Set up checkpointing.
  # The null checkpoint never saves or restores, so every run starts at
  # iteration 0.
  checkpoint = parts.NullCheckpoint()

  state = checkpoint.state
  state.iteration = 0
  state.train_agent = train_agent
  state.eval_agent_active = eval_agent_active
  state.eval_agent_passive = eval_agent_passive
  state.random_state = random_state
  state.writer = writer
  if checkpoint.can_be_restored():
    checkpoint.restore()

  # Run single iteration of training or evaluation.
  def run_iteration(agent, env, num_frames):
    """Runs `agent` on `env` for `num_frames` loop steps, returns statistics."""
    seq = parts.run_loop(agent, env, FLAGS.max_frames_per_episode)
    seq_truncated = itertools.islice(seq, num_frames)
    trackers = parts.make_default_trackers(agent)
    return parts.generate_statistics(trackers, seq_truncated)

  def eval_log_output(eval_stats, suffix):
    """Returns (name, value, format) log entries for one evaluated agent."""
    human_normalized_score = atari_data.get_human_normalized_score(
        FLAGS.environment_name, eval_stats['episode_return'])
    capped_human_normalized_score = np.amin([1., human_normalized_score])
    return [
        ('eval_episode_return_' + suffix,
         eval_stats['episode_return'], '% 2.2f'),
        ('eval_num_episodes_' + suffix,
         eval_stats['num_episodes'], '%3d'),
        ('eval_frame_rate_' + suffix,
         eval_stats['step_rate'], '%4.0f'),
        ('normalized_return_' + suffix,
         human_normalized_score, '%.3f'),
        ('capped_normalized_return_' + suffix,
         capped_human_normalized_score, '%.3f'),
        ('human_gap_' + suffix,
         1. - capped_human_normalized_score, '%.3f'),
    ]

  # The bound is inclusive: starting from iteration 0 the loop body runs
  # `num_iterations + 1` times, and iteration 0 trains on zero frames.
  while state.iteration <= FLAGS.num_iterations:
    # New environment for each iteration to allow for determinism if preempted.
    env = environment_builder()

    # Set agent to train active and passive nets on each learning step.
    train_agent.set_training_mode('active_passive')

    logging.info('Training iteration %d.', state.iteration)
    num_train_frames = 0 if state.iteration == 0 else FLAGS.num_train_frames
    train_stats = run_iteration(train_agent, env, num_train_frames)

    # Both evaluations reuse this iteration's environment and take the
    # current online parameters of the train agent.
    logging.info('Evaluation iteration %d - active agent.', state.iteration)
    eval_agent_active.network_params = train_agent.online_params.active
    eval_stats_active = run_iteration(eval_agent_active, env,
                                      FLAGS.num_eval_frames)

    logging.info('Evaluation iteration %d - passive agent.', state.iteration)
    eval_agent_passive.network_params = train_agent.online_params.passive
    eval_stats_passive = run_iteration(eval_agent_passive, env,
                                       FLAGS.num_eval_frames)

    # Logging and checkpointing.
    # Keys looked up in the training statistics and logged as they are.
    agent_logs = [
        'loss_active',
        'loss_passive',
        'frac_diff_argmax',
        'mc_error_active',
        'mc_error_passive',
        'mc_error_abs_active',
        'mc_error_abs_passive',
    ]
    log_output = (
        eval_log_output(eval_stats_active, 'active') +
        eval_log_output(eval_stats_passive, 'passive') +
        [('iteration', state.iteration, '%3d'),
         ('frame', state.iteration * FLAGS.num_train_frames, '%5d'),
         ('train_episode_return', train_stats['episode_return'], '% 2.2f'),
         ('train_num_episodes', train_stats['num_episodes'], '%3d'),
         ('train_frame_rate', train_stats['step_rate'], '%4.0f'),
        ] +
        [(k, train_stats[k], '% 2.2f') for k in agent_logs]
    )
    # Each entry is (name, value, format); render as "name: value" pairs.
    log_output_str = ', '.join(('%s: ' + f) % (n, v) for n, v, f in log_output)
    logging.info(log_output_str)
    writer.write(collections.OrderedDict((n, v) for n, v, _ in log_output))
    state.iteration += 1
    checkpoint.save()

  writer.close()
|
||||
|
||||
|
||||
# Script entry point: set JAX config values, then hand control to absl.
if __name__ == '__main__':
  config.update('jax_platform_name', 'gpu')  # Default to GPU.
  # Implicit rank promotion in jax.numpy raises an error.
  config.update('jax_numpy_rank_promotion', 'raise')
  # NOTE(review): presumably exposes the JAX config values as absl flags
  # (e.g. --jax_platform_name) -- confirm against the JAX version in use.
  config.config_with_absl()
  app.run(main)
|
||||
Reference in New Issue
Block a user