Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+3 -3
View File
@@ -82,7 +82,7 @@ class Agent():
def option_values(values, policy):
return tf.tensordot(
values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])
values[:, policy, ...], self._policy_weights[policy], axes=[1, 0])
# Placeholders for policy.
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
@@ -103,8 +103,8 @@ class Agent():
qo_t = option_values(q_t, p)
a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1)
qa_t = _batched_index(q_t[:, p, Ellipsis], a_t)
qa_tm1 = _batched_index(q_tm1[:, p, ...], a_tm1)
qa_t = _batched_index(q_t[:, p, ...], a_t)
# TD error
g = additional_discount * tf.expand_dims(d_t, axis=-1)