Add checkpoints from the ablation study.

PiperOrigin-RevId: 328023346
This commit is contained in:
Florent Altché
2020-08-23 14:26:26 +01:00
committed by Diego de Las Casas
parent 22c3daff19
commit 8457046b2c
33 changed files with 397 additions and 363 deletions
+16 -16
View File
@@ -106,7 +106,7 @@ TESTING_SUITE = [
ALL = TUNING_SUITE + TESTING_SUITE
def _decode_frames(pngs):
def _decode_frames(pngs: tf.Tensor):
"""Decode PNGs.
Args:
@@ -122,13 +122,13 @@ def _decode_frames(pngs):
return frames
def _make_reverb_sample(o_t,
a_t,
r_t,
d_t,
o_tp1,
a_tp1,
extras):
def _make_reverb_sample(o_t: tf.Tensor,
a_t: tf.Tensor,
r_t: tf.Tensor,
d_t: tf.Tensor,
o_tp1: tf.Tensor,
a_tp1: tf.Tensor,
extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample:
"""Create Reverb sample with offline data.
Args:
@@ -151,8 +151,8 @@ def _make_reverb_sample(o_t,
return reverb.ReplaySample(info=info, data=data)
def _tf_example_to_reverb_sample(tf_example
):
def _tf_example_to_reverb_sample(tf_example: tf.train.Example
) -> reverb.ReplaySample:
"""Create a Reverb replay sample from a TF example."""
# Parse tf.Example.
@@ -184,11 +184,11 @@ def _tf_example_to_reverb_sample(tf_example
extras)
def dataset(path,
game,
run,
num_shards = 100,
shuffle_buffer_size = 100000):
def dataset(path: str,
game: str,
run: int,
num_shards: int = 100,
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
"""TF dataset of Atari SARSA tuples."""
path = os.path.join(path, f'{game}/run_{run}')
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
@@ -243,7 +243,7 @@ class AtariDopamineWrapper(dm_env.Environment):
return specs.DiscreteArray(self._env.action_space.n)
def environment(game):
def environment(game: str) -> dm_env.Environment:
"""Atari environment."""
env = atari_lib.create_atari_environment(game_name=game,
sticky_actions=True)