mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
b8cf44502b
PiperOrigin-RevId: 288392606
158 lines
5.3 KiB
Python
158 lines
5.3 KiB
Python
# Lint as: python2, python3
|
|
# pylint: disable=g-bad-file-header
|
|
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
"""Loss functions."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import six
|
|
import tensorflow.compat.v1 as tf
|
|
|
|
|
|
def sum_time_average_batch(tensor, name=None):
|
|
"""Computes the mean over B assuming tensor is of shape [T, B]."""
|
|
tensor.get_shape().assert_has_rank(2)
|
|
return tf.reduce_mean(tf.reduce_sum(tensor, axis=0), axis=0, name=name)
|
|
|
|
|
|
def combine_logged_values(*logged_values_dicts):
|
|
"""Combine logged values dicts. Throws if there are any repeated keys."""
|
|
combined_dict = dict()
|
|
for logged_values in logged_values_dicts:
|
|
for k, v in six.iteritems(logged_values):
|
|
if k in combined_dict:
|
|
raise ValueError('Key "%s" is repeated in loss logging.' % k)
|
|
combined_dict[k] = v
|
|
return combined_dict
|
|
|
|
|
|
def reconstruction_losses(
|
|
recons,
|
|
targets,
|
|
image_cost,
|
|
action_cost,
|
|
reward_cost):
|
|
"""Reconstruction losses."""
|
|
if image_cost > 0.0:
|
|
# Neg log prob of obs image given Bernoulli(recon image) distribution.
|
|
negative_image_log_prob = tf.nn.sigmoid_cross_entropy_with_logits(
|
|
labels=targets.image, logits=recons.image)
|
|
nll_per_time = tf.reduce_sum(negative_image_log_prob, [-3, -2, -1])
|
|
image_loss = image_cost * nll_per_time
|
|
image_loss = sum_time_average_batch(image_loss)
|
|
else:
|
|
image_loss = tf.constant(0.)
|
|
|
|
if action_cost > 0.0 and recons.last_action is not tuple():
|
|
# Labels have shape (T, B), logits have shape (T, B, num_actions).
|
|
action_loss = action_cost * tf.nn.sparse_softmax_cross_entropy_with_logits(
|
|
labels=targets.last_action, logits=recons.last_action)
|
|
action_loss = sum_time_average_batch(action_loss)
|
|
else:
|
|
action_loss = tf.constant(0.)
|
|
|
|
if reward_cost > 0.0 and recons.last_reward is not tuple():
|
|
# MSE loss for reward.
|
|
recon_last_reward = recons.last_reward
|
|
recon_last_reward = tf.squeeze(recon_last_reward, -1)
|
|
reward_loss = 0.5 * reward_cost * tf.square(
|
|
recon_last_reward - targets.last_reward)
|
|
reward_loss = sum_time_average_batch(reward_loss)
|
|
else:
|
|
reward_loss = tf.constant(0.)
|
|
|
|
total_loss = image_loss + action_loss + reward_loss
|
|
|
|
logged_values = dict(
|
|
recon_loss_image=image_loss,
|
|
recon_loss_action=action_loss,
|
|
recon_loss_reward=reward_loss,
|
|
total_reconstruction_loss=total_loss,)
|
|
|
|
return total_loss, logged_values
|
|
|
|
|
|
def read_regularization_loss(
|
|
read_info,
|
|
strength_cost,
|
|
strength_tolerance,
|
|
strength_reg_mode,
|
|
key_norm_cost,
|
|
key_norm_tolerance):
|
|
"""Computes the sum of read strength and read key regularization losses."""
|
|
|
|
if (strength_cost <= 0.) and (key_norm_cost <= 0.):
|
|
read_reg_loss = tf.constant(0.)
|
|
return read_reg_loss, dict(read_regularization_loss=read_reg_loss)
|
|
|
|
if hasattr(read_info, 'read_strengths'):
|
|
read_strengths = read_info.read_strengths
|
|
read_keys = read_info.read_keys
|
|
else:
|
|
read_strengths = read_info.strengths
|
|
read_keys = read_info.keys
|
|
|
|
if read_info == tuple():
|
|
raise ValueError('Make sure read regularization costs are zero when '
|
|
'not outputting read info.')
|
|
|
|
read_reg_loss = tf.constant(0.)
|
|
if strength_cost > 0.:
|
|
strength_hinged = tf.maximum(strength_tolerance, read_strengths)
|
|
if strength_reg_mode == 'L2':
|
|
strength_loss = 0.5 * tf.square(strength_hinged)
|
|
elif strength_reg_mode == 'L1':
|
|
# Read strengths are always positive.
|
|
strength_loss = strength_hinged
|
|
else:
|
|
raise ValueError(
|
|
'Strength regularization mode "{}" is not supported.'.format(
|
|
strength_reg_mode))
|
|
|
|
# Sum across read heads to reduce from [T, B, n_reads] to [T, B].
|
|
strength_loss = strength_cost * tf.reduce_sum(strength_loss, axis=2)
|
|
|
|
if key_norm_cost > 0.:
|
|
key_norm_norms = tf.norm(read_keys, axis=-1)
|
|
key_norm_norms_hinged = tf.maximum(key_norm_tolerance, key_norm_norms)
|
|
key_norm_loss = 0.5 * tf.square(key_norm_norms_hinged)
|
|
|
|
# Sum across read heads to reduce from [T, B, n_reads] to [T, B].
|
|
key_norm_loss = key_norm_cost * tf.reduce_sum(key_norm_loss, axis=2)
|
|
|
|
read_reg_loss += key_norm_cost * key_norm_loss
|
|
|
|
if strength_cost > 0.:
|
|
strength_loss = sum_time_average_batch(strength_loss)
|
|
else:
|
|
strength_loss = tf.constant(0.)
|
|
|
|
if key_norm_cost > 0.:
|
|
key_norm_loss = sum_time_average_batch(key_norm_loss)
|
|
else:
|
|
key_norm_loss = tf.constant(0.)
|
|
|
|
read_reg_loss = strength_loss + key_norm_loss
|
|
|
|
logged_values = dict(
|
|
read_reg_strength_loss=strength_loss,
|
|
read_reg_key_norm_loss=key_norm_loss,
|
|
total_read_reg_loss=read_reg_loss)
|
|
|
|
return read_reg_loss, logged_values
|