mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
@@ -82,7 +82,7 @@ class Agent():
|
||||
|
||||
def option_values(values, policy):
|
||||
return tf.tensordot(
|
||||
values[:, policy, Ellipsis], self._policy_weights[policy], axes=[1, 0])
|
||||
values[:, policy, ...], self._policy_weights[policy], axes=[1, 0])
|
||||
|
||||
# Placeholders for policy.
|
||||
o = tf.placeholder(shape=obs_spec.shape, dtype=obs_spec.dtype)
|
||||
@@ -103,8 +103,8 @@ class Agent():
|
||||
qo_t = option_values(q_t, p)
|
||||
|
||||
a_t = tf.cast(tf.argmax(qo_t, axis=-1), tf.int32)
|
||||
qa_tm1 = _batched_index(q_tm1[:, p, Ellipsis], a_tm1)
|
||||
qa_t = _batched_index(q_t[:, p, Ellipsis], a_t)
|
||||
qa_tm1 = _batched_index(q_tm1[:, p, ...], a_tm1)
|
||||
qa_t = _batched_index(q_t[:, p, ...], a_t)
|
||||
|
||||
# TD error
|
||||
g = additional_discount * tf.expand_dims(d_t, axis=-1)
|
||||
|
||||
Reference in New Issue
Block a user