deepmind-research/tvt/losses.py

# Lint as: python2, python3
# pylint: disable=g-bad-file-header
# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Loss functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import tensorflow.compat.v1 as tf


def sum_time_average_batch(tensor, name=None):
  """Sums over time (axis 0), then averages over batch, given a [T, B] tensor."""
tensor.get_shape().assert_has_rank(2)
return tf.reduce_mean(tf.reduce_sum(tensor, axis=0), axis=0, name=name)
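
# Illustrative usage sketch (not part of the original module; the tensor below
# is hypothetical): summing over T and averaging over B turns a [T, B] tensor
# of per-step losses into a scalar.
#   per_step_losses = tf.ones([3, 2])               # [T=3, B=2]
#   loss = sum_time_average_batch(per_step_losses)  # evaluates to 3.0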


def combine_logged_values(*logged_values_dicts):
  """Combines logged-values dicts; raises ValueError on any repeated key."""
combined_dict = dict()
for logged_values in logged_values_dicts:
for k, v in six.iteritems(logged_values):
if k in combined_dict:
raise ValueError('Key "%s" is repeated in loss logging.' % k)
combined_dict[k] = v
return combined_dict
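
# Illustrative usage sketch (hypothetical dicts, not part of the original
# module): disjoint dicts merge cleanly; a repeated key raises ValueError.
#   logs = combine_logged_values({'loss_a': 1.0}, {'loss_b': 2.0})
#   combine_logged_values({'loss_a': 1.0}, {'loss_a': 2.0})  # ValueError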


def reconstruction_losses(
    recons,
    targets,
    image_cost,
    action_cost,
    reward_cost):
  """Computes weighted image, action, and reward reconstruction losses."""
if image_cost > 0.0:
# Neg log prob of obs image given Bernoulli(recon image) distribution.
negative_image_log_prob = tf.nn.sigmoid_cross_entropy_with_logits(
labels=targets.image, logits=recons.image)
nll_per_time = tf.reduce_sum(negative_image_log_prob, [-3, -2, -1])
image_loss = image_cost * nll_per_time
image_loss = sum_time_average_batch(image_loss)
else:
image_loss = tf.constant(0.)
  if action_cost > 0.0 and not isinstance(recons.last_action, tuple):
# Labels have shape (T, B), logits have shape (T, B, num_actions).
action_loss = action_cost * tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=targets.last_action, logits=recons.last_action)
action_loss = sum_time_average_batch(action_loss)
else:
action_loss = tf.constant(0.)
  if reward_cost > 0.0 and not isinstance(recons.last_reward, tuple):
# MSE loss for reward.
    # Squeeze the trailing singleton dimension: [T, B, 1] -> [T, B].
    recon_last_reward = tf.squeeze(recons.last_reward, -1)
reward_loss = 0.5 * reward_cost * tf.square(
recon_last_reward - targets.last_reward)
reward_loss = sum_time_average_batch(reward_loss)
else:
reward_loss = tf.constant(0.)
total_loss = image_loss + action_loss + reward_loss
logged_values = dict(
recon_loss_image=image_loss,
recon_loss_action=action_loss,
recon_loss_reward=reward_loss,
      total_reconstruction_loss=total_loss)
return total_loss, logged_values
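
# Illustrative usage sketch (the `Recons` namedtuple and all shapes below are
# hypothetical, chosen to match the shape comments above): `image` carries
# [T, B, H, W, C] logits/targets, `last_action` carries [T, B, num_actions]
# logits against [T, B] integer labels, and `last_reward` carries a [T, B, 1]
# reconstruction against a [T, B] target.
#   Recons = collections.namedtuple(
#       'Recons', ['image', 'last_action', 'last_reward'])
#   recons = Recons(image_logits, action_logits, reward_recons)
#   targets = Recons(target_images, target_actions, target_rewards)
#   total, logs = reconstruction_losses(
#       recons, targets, image_cost=1.0, action_cost=1.0, reward_cost=1.0)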


def read_regularization_loss(
    read_info,
    strength_cost,
    strength_tolerance,
    strength_reg_mode,
    key_norm_cost,
    key_norm_tolerance):
  """Computes the sum of read strength and read key regularization losses."""
  if (strength_cost <= 0.) and (key_norm_cost <= 0.):
    read_reg_loss = tf.constant(0.)
    return read_reg_loss, dict(read_regularization_loss=read_reg_loss)

  # Check for missing read info before touching its fields; an empty tuple
  # means the agent did not output read information.
  if read_info == tuple():
    raise ValueError('Make sure read regularization costs are zero when '
                     'not outputting read info.')

  if hasattr(read_info, 'read_strengths'):
    read_strengths = read_info.read_strengths
    read_keys = read_info.read_keys
  else:
    read_strengths = read_info.strengths
    read_keys = read_info.keys

if strength_cost > 0.:
strength_hinged = tf.maximum(strength_tolerance, read_strengths)
if strength_reg_mode == 'L2':
strength_loss = 0.5 * tf.square(strength_hinged)
elif strength_reg_mode == 'L1':
# Read strengths are always positive.
strength_loss = strength_hinged
else:
raise ValueError(
'Strength regularization mode "{}" is not supported.'.format(
strength_reg_mode))
# Sum across read heads to reduce from [T, B, n_reads] to [T, B].
strength_loss = strength_cost * tf.reduce_sum(strength_loss, axis=2)

  if key_norm_cost > 0.:
key_norm_norms = tf.norm(read_keys, axis=-1)
key_norm_norms_hinged = tf.maximum(key_norm_tolerance, key_norm_norms)
key_norm_loss = 0.5 * tf.square(key_norm_norms_hinged)
# Sum across read heads to reduce from [T, B, n_reads] to [T, B].
key_norm_loss = key_norm_cost * tf.reduce_sum(key_norm_loss, axis=2)
if strength_cost > 0.:
strength_loss = sum_time_average_batch(strength_loss)
else:
strength_loss = tf.constant(0.)
if key_norm_cost > 0.:
key_norm_loss = sum_time_average_batch(key_norm_loss)
else:
key_norm_loss = tf.constant(0.)
read_reg_loss = strength_loss + key_norm_loss
logged_values = dict(
read_reg_strength_loss=strength_loss,
read_reg_key_norm_loss=key_norm_loss,
total_read_reg_loss=read_reg_loss)
return read_reg_loss, logged_values
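
# Illustrative usage sketch (the `ReadInfo` namedtuple and shapes below are
# hypothetical; this function only requires the attribute names it probes
# for): `strengths` is [T, B, n_reads] and `keys` is [T, B, n_reads, key_dim].
#   ReadInfo = collections.namedtuple('ReadInfo', ['strengths', 'keys'])
#   read_info = ReadInfo(strengths=tf.ones([5, 2, 3]),
#                        keys=tf.ones([5, 2, 3, 16]))
#   loss, logs = read_regularization_loss(
#       read_info, strength_cost=1e-2, strength_tolerance=1.0,
#       strength_reg_mode='L2', key_norm_cost=1e-3, key_norm_tolerance=1.0)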