diff --git a/rl_unplugged/dm_control_suite.py b/rl_unplugged/dm_control_suite.py index b59939e..5360e03 100644 --- a/rl_unplugged/dm_control_suite.py +++ b/rl_unplugged/dm_control_suite.py @@ -30,7 +30,7 @@ observations. import collections import functools import os -from typing import Dict, Tuple, Set +from typing import Dict, Optional, Tuple, Set from acme import wrappers from dm_control import composer @@ -778,7 +778,7 @@ def dataset(root_path: str, shapes: Dict[str, Tuple[int]], num_threads: int, batch_size: int, - uint8_features: Set[str] = None, + uint8_features: Optional[Set[str]] = None, num_shards: int = 100, shuffle_buffer_size: int = 100000, sarsa: bool = True) -> tf.data.Dataset: