Mirror of https://github.com/google-deepmind/deepmind-research.git (synced 2026-05-09 21:07:49 +08:00)
An example of training CRR agent with TPU support
PiperOrigin-RevId: 334117027
Committed by Saran Tunyasuvunakool
parent 0e5237df2a
commit 99976cfaa9
@@ -33,6 +33,7 @@ import os
 from typing import Dict, Optional, Tuple, Set

 from acme import wrappers
+from acme.adders import reverb as adders
 from dm_control import composer
 from dm_control import suite
 from dm_control.composer.variation import colors
@@ -726,16 +727,19 @@ def _parse_seq_tf_example(example, uint8_features, shapes):

 def _build_sequence_example(sequences):
   """Convert raw sequences into a Reverb sequence sample."""
-  o = sequences['observation']
-  a = sequences['action']
-  r = sequences['reward']
-  p = sequences['discount']
+  data = adders.Step(
+      observation=sequences['observation'],
+      action=sequences['action'],
+      reward=sequences['reward'],
+      discount=sequences['discount'],
+      start_of_episode=(),
+      extras=())

   info = reverb.SampleInfo(key=tf.constant(0, tf.uint64),
                            probability=tf.constant(1.0, tf.float64),
                            table_size=tf.constant(0, tf.int64),
                            priority=tf.constant(1.0, tf.float64))
-  return reverb.ReplaySample(info=info, data=(o, a, r, p))
+  return reverb.ReplaySample(info=info, data=data)


 def _build_sarsa_example(sequences):
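Note on the change above: the sample payload is now packed into an adders.Step namedtuple rather than the bare (o, a, r, p) tuple, so downstream code can read fields by name. A minimal, illustrative sketch of calling the updated helper (not part of the commit; the tensor shapes below are made up):

# Illustrative only: dummy sequence tensors fed through the updated helper.
import tensorflow as tf

sequences = {
    'observation': tf.zeros([10, 24]),  # hypothetical sequence of length 10
    'action': tf.zeros([10, 6]),
    'reward': tf.zeros([10]),
    'discount': tf.ones([10]),
}
sample = _build_sequence_example(sequences)
# data is an adders.Step, so fields are accessed by name rather than position.
observations = sample.data.observation
rewards = sample.data.reward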
File diff suppressed because it is too large
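The suppressed file above presumably contains the TPU-enabled training example named in the commit message. Since its contents are not shown, the following is only a generic TF2 TPU-initialization sketch (standard TensorFlow API, not code from this commit; 'local' is a placeholder TPU address):

# Generic TF2 TPU setup sketch -- not taken from the suppressed file.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
  # Networks and optimizer variables would be created here so that they are
  # replicated across TPU cores by the strategy.
  pass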
@@ -0,0 +1,93 @@
# Lint as: python3
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Networks used for training agents.
"""

from acme.tf import networks as acme_networks
from acme.tf import utils as tf2_utils
import numpy as np
import sonnet as snt
import tensorflow as tf


def instance_norm_and_elu(x):
  mean = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
  x_ = x - mean
  var = tf.reduce_mean(x_**2, axis=[1, 2], keepdims=True)
  x_norm = x_ / (var + 1e-6)
  return tf.nn.elu(x_norm)


class ControlNetwork(snt.Module):
  """Image, proprio and optionally action encoder used for actors and critics.
  """

  def __init__(self,
               proprio_encoder_size: int,
               proprio_keys=None,
               activation=tf.nn.elu):
    """Creates a ControlNetwork.

    Args:
      proprio_encoder_size: Size of the linear layer for the proprio encoder.
      proprio_keys: Optional list of names of proprioceptive observations.
        Defaults to all observations. Note that if this is specified, any
        observation not contained in proprio_keys will be ignored by the agent.
      activation: Linear layer activation function.
    """
    super().__init__(name='control_network')
    self._activation = activation
    self._proprio_keys = proprio_keys

    self._proprio_encoder = acme_networks.LayerNormMLP([proprio_encoder_size])

  def __call__(self, inputs, action: tf.Tensor = None, task=None):
    """Evaluates the ControlNetwork.

    Args:
      inputs: A dictionary of agent observation tensors.
      action: Agent actions.
      task: Optional encoding of the task.

    Raises:
      ValueError: if no proprio input is provided.
      ValueError: if some proprio input looks suspiciously like pixel inputs.

    Returns:
      Processed network output.
    """
    if not isinstance(inputs, dict):
      inputs = {'inputs': inputs}

    proprio_input = []
    # By default, treat all observations as proprioceptive.
    if self._proprio_keys is None:
      self._proprio_keys = list(sorted(inputs.keys()))
    for key in self._proprio_keys:
      proprio_input.append(snt.Flatten()(inputs[key]))
      if np.prod(inputs[key].shape[1:]) > 32*32*3:
        raise ValueError(
            'This input does not resemble a proprioceptive '
            'state: {} with shape {}'.format(
                key, inputs[key].shape))

    # Append optional action input (i.e. for critic networks).
    if action is not None:
      proprio_input.append(action)

    proprio_input = tf2_utils.batch_concat(proprio_input)
    proprio_state = self._proprio_encoder(proprio_input)

    return proprio_state
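A small usage sketch for the ControlNetwork defined above (illustrative only, not part of the commit; observation names, feature sizes and the batch size are made up):

# Illustrative usage of ControlNetwork with dummy proprioceptive inputs.
import tensorflow as tf

observations = {
    'position': tf.random.normal([2, 4]),
    'velocity': tf.random.normal([2, 4]),
}
action = tf.random.normal([2, 3])

# Critic-style encoder: proprio observations and the action are flattened,
# concatenated and passed through the LayerNormMLP.
critic_encoder = ControlNetwork(proprio_encoder_size=300)
state_action = critic_encoder(observations, action=action)  # shape [2, 300]

# An actor would typically use a separate instance called without an action,
# since the underlying linear layer is built for a fixed input width.
actor_encoder = ControlNetwork(proprio_encoder_size=300)
state = actor_encoder(observations)  # shape [2, 300]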