An example of training a CRR agent with TPU support

PiperOrigin-RevId: 334117027
This commit is contained in:
Alexander Novikov
2020-09-28 10:53:06 +01:00
committed by Saran Tunyasuvunakool
parent 0e5237df2a
commit 99976cfaa9
3 changed files with 616 additions and 5 deletions
+9 -5
View File
@@ -33,6 +33,7 @@ import os
from typing import Dict, Optional, Tuple, Set from typing import Dict, Optional, Tuple, Set
from acme import wrappers from acme import wrappers
from acme.adders import reverb as adders
from dm_control import composer from dm_control import composer
from dm_control import suite from dm_control import suite
from dm_control.composer.variation import colors from dm_control.composer.variation import colors
@@ -726,16 +727,19 @@ def _parse_seq_tf_example(example, uint8_features, shapes):
def _build_sequence_example(sequences): def _build_sequence_example(sequences):
"""Convert raw sequences into a Reverb sequence sample.""" """Convert raw sequences into a Reverb sequence sample."""
o = sequences['observation'] data = adders.Step(
a = sequences['action'] observation=sequences['observation'],
r = sequences['reward'] action=sequences['action'],
p = sequences['discount'] reward=sequences['reward'],
discount=sequences['discount'],
start_of_episode=(),
extras=())
info = reverb.SampleInfo(key=tf.constant(0, tf.uint64), info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
probability=tf.constant(1.0, tf.float64), probability=tf.constant(1.0, tf.float64),
table_size=tf.constant(0, tf.int64), table_size=tf.constant(0, tf.int64),
priority=tf.constant(1.0, tf.float64)) priority=tf.constant(1.0, tf.float64))
return reverb.ReplaySample(info=info, data=(o, a, r, p)) return reverb.ReplaySample(info=info, data=data)
def _build_sarsa_example(sequences): def _build_sarsa_example(sequences):
File diff suppressed because it is too large Load Diff
+93
View File
@@ -0,0 +1,93 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Networks used for training agents.
"""
from acme.tf import networks as acme_networks
from acme.tf import utils as tf2_utils
import numpy as np
import sonnet as snt
import tensorflow as tf
def instance_norm_and_elu(x):
  """Centers and rescales `x` over axes 1 and 2, then applies ELU.

  The mean and the mean squared deviation are reduced over axes 1 and 2 only
  (kept as size-1 dimensions so they broadcast back against `x`); every index
  along the remaining axes is therefore normalized independently.

  NOTE(review): the centered values are divided by the variance (plus 1e-6),
  not by its square root as in textbook instance normalization. Confirm this
  is intended before changing it, since it affects every caller's numerics.

  Args:
    x: Input tensor with at least three axes. Presumably laid out as
      [batch, height, width, channels] -- TODO confirm against callers.

  Returns:
    A tensor with the same shape as `x`.
  """
  reduce_axes = [1, 2]
  centered = x - tf.reduce_mean(x, axis=reduce_axes, keepdims=True)
  variance = tf.reduce_mean(centered**2, axis=reduce_axes, keepdims=True)
  # The small constant keeps the division finite when the variance is zero.
  return tf.nn.elu(centered / (variance + 1e-6))
class ControlNetwork(snt.Module):
  """Proprio and optionally action encoder used for actors and critics.

  The selected observations are flattened, concatenated (together with the
  action, when one is given) and passed through a single-layer LayerNormMLP.
  """

  def __init__(self,
               proprio_encoder_size: int,
               proprio_keys=None,
               activation=tf.nn.elu):
    """Creates a ControlNetwork.

    Args:
      proprio_encoder_size: Size of the linear layer for the proprio encoder.
      proprio_keys: Optional list of names of proprioceptive observations.
        Defaults to all observations. Note that if this is specified, any
        observation not contained in proprio_keys will be ignored by the agent.
      activation: Linear layer activation function. It is stored on the module
        but is not applied anywhere in this class.
    """
    super().__init__(name='control_network')
    self._activation = activation
    self._proprio_keys = proprio_keys
    self._proprio_encoder = acme_networks.LayerNormMLP([proprio_encoder_size])

  def __call__(self, inputs, action: tf.Tensor = None, task=None):
    """Evaluates the ControlNetwork.

    Args:
      inputs: A dictionary of agent observation tensors. A non-dict value is
        wrapped as `{'inputs': value}`.
      action: Agent actions. When given, they are appended to the flattened
        observations before encoding (i.e. for critic networks).
      task: Optional encoding of the task. Not used by this network.

    Raises:
      ValueError: if neither proprioceptive observations nor an action are
        provided, so there is nothing to encode.
      ValueError: if some proprio input looks suspiciously like pixel inputs
        (more than 32*32*3 elements per batch entry).

    Returns:
      The output of the proprio encoder.
    """
    if not isinstance(inputs, dict):
      inputs = {'inputs': inputs}

    # By default, treat all observations as proprioceptive. The key list is
    # stored on first use, so later calls read the same (sorted) keys.
    if self._proprio_keys is None:
      self._proprio_keys = sorted(inputs)

    proprio_input = []
    for key in self._proprio_keys:
      observation = inputs[key]
      # Reject pixel-sized inputs before building any op on them.
      if np.prod(observation.shape[1:]) > 32*32*3:
        raise ValueError(
            'This input does not resemble a proprioceptive '
            'state: {} with shape {}'.format(
                key, observation.shape))
      proprio_input.append(snt.Flatten()(observation))

    # Append optional action input (i.e. for critic networks).
    if action is not None:
      proprio_input.append(action)

    # Fail with a clear message instead of an opaque concat error downstream.
    if not proprio_input:
      raise ValueError(
          'ControlNetwork received no proprioceptive observations and no '
          'action; there is nothing to encode.')

    proprio_input = tf2_utils.batch_concat(proprio_input)
    return self._proprio_encoder(proprio_input)