mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-28 02:35:47 +08:00
Add checkpoints from the ablation study.
PiperOrigin-RevId: 328023346
This commit is contained in:
committed by
Diego de Las Casas
parent
22c3daff19
commit
8457046b2c
+16
-16
@@ -106,7 +106,7 @@ TESTING_SUITE = [
|
||||
ALL = TUNING_SUITE + TESTING_SUITE
|
||||
|
||||
|
||||
def _decode_frames(pngs):
|
||||
def _decode_frames(pngs: tf.Tensor):
|
||||
"""Decode PNGs.
|
||||
|
||||
Args:
|
||||
@@ -122,13 +122,13 @@ def _decode_frames(pngs):
|
||||
return frames
|
||||
|
||||
|
||||
def _make_reverb_sample(o_t,
|
||||
a_t,
|
||||
r_t,
|
||||
d_t,
|
||||
o_tp1,
|
||||
a_tp1,
|
||||
extras):
|
||||
def _make_reverb_sample(o_t: tf.Tensor,
|
||||
a_t: tf.Tensor,
|
||||
r_t: tf.Tensor,
|
||||
d_t: tf.Tensor,
|
||||
o_tp1: tf.Tensor,
|
||||
a_tp1: tf.Tensor,
|
||||
extras: Dict[str, tf.Tensor]) -> reverb.ReplaySample:
|
||||
"""Create Reverb sample with offline data.
|
||||
|
||||
Args:
|
||||
@@ -151,8 +151,8 @@ def _make_reverb_sample(o_t,
|
||||
return reverb.ReplaySample(info=info, data=data)
|
||||
|
||||
|
||||
def _tf_example_to_reverb_sample(tf_example
|
||||
):
|
||||
def _tf_example_to_reverb_sample(tf_example: tf.train.Example
|
||||
) -> reverb.ReplaySample:
|
||||
"""Create a Reverb replay sample from a TF example."""
|
||||
|
||||
# Parse tf.Example.
|
||||
@@ -184,11 +184,11 @@ def _tf_example_to_reverb_sample(tf_example
|
||||
extras)
|
||||
|
||||
|
||||
def dataset(path,
|
||||
game,
|
||||
run,
|
||||
num_shards = 100,
|
||||
shuffle_buffer_size = 100000):
|
||||
def dataset(path: str,
|
||||
game: str,
|
||||
run: int,
|
||||
num_shards: int = 100,
|
||||
shuffle_buffer_size: int = 100000) -> tf.data.Dataset:
|
||||
"""TF dataset of Atari SARSA tuples."""
|
||||
path = os.path.join(path, f'{game}/run_{run}')
|
||||
filenames = [f'{path}-{i:05d}-of-{num_shards:05d}' for i in range(num_shards)]
|
||||
@@ -243,7 +243,7 @@ class AtariDopamineWrapper(dm_env.Environment):
|
||||
return specs.DiscreteArray(self._env.action_space.n)
|
||||
|
||||
|
||||
def environment(game):
|
||||
def environment(game: str) -> dm_env.Environment:
|
||||
"""Atari environment."""
|
||||
env = atari_lib.create_atari_environment(game_name=game,
|
||||
sticky_actions=True)
|
||||
|
||||
Reference in New Issue
Block a user