# Copyright 2019 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Vanilla Q-Learning agent."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import numpy as np
|
|
from six.moves import range
|
|
|
|
|
|
class EpsilonGreedyPolicy(object):
  """Epsilon greedy policy for table value function lookup."""

  def __init__(self, value_function, actions):
    """Construct an epsilon greedy policy object.

    Args:
      value_function: agent value function as a dict.
      actions: list of possible actions.

    Raises:
      ValueError: if `actions` argument is not an iterable.
    """
    if not isinstance(actions, collections.abc.Iterable):
      raise ValueError('`actions` argument must be an iterable.')

    self._value_function = value_function
    self._actions = actions

  def get_action(self, epsilon, state):
    """Get action following the e-greedy policy.

    Args:
      epsilon: probability of selecting a random action.
      state: current state of the game; must be hashable, since
        (state, action) tuples are used as keys into the value function.

    Returns:
      Chosen action.
    """
    # With probability epsilon, explore by sampling an action uniformly.
    if np.random.random() < epsilon:
      return np.random.choice(self._actions)
    else:
      values = [self._value_function[(state, action)]
                for action in self._actions]

      max_value = max(values)
      max_indices = [i for i, value in enumerate(values) if value == max_value]

      # Otherwise act greedily, breaking ties uniformly at random.
      return self._actions[np.random.choice(max_indices)]
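

# A sketch of the policy in use (the value table, state and actions below are
# hypothetical, chosen only to illustrate the greedy and exploratory cases):
#
#   policy = EpsilonGreedyPolicy({('s', 0): 0.0, ('s', 1): 1.0}, [0, 1])
#   policy.get_action(epsilon=0.0, state='s')  # always 1, the greedy action
#   policy.get_action(epsilon=1.0, state='s')  # uniformly random over [0, 1]
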

class QLearning(object):
  """Q-learning agent."""

  def __init__(self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0,
               discount=0.99):
    """Create a Q-learning agent.

    Args:
      actions: a BoundedArraySpec that specifies the full discrete action spec.
      alpha: agent learning rate.
      epsilon: agent exploration rate.
      q_initialisation: float, used to initialise the value function.
      discount: discount factor for rewards.
    """
    self._value_function = collections.defaultdict(lambda: q_initialisation)
    self._valid_actions = list(range(actions.minimum, actions.maximum + 1))
    self._policy = EpsilonGreedyPolicy(self._value_function,
                                       self._valid_actions)

    # Hyperparameters.
    self.alpha = alpha
    self.epsilon = epsilon
    self.discount = discount

    # Episode internal variables.
    self._current_action = None
    self._current_state = None

  def begin_episode(self):
    """Perform episode initialisation."""
    self._current_state = None
    self._current_action = None

  def _timestep_to_state(self, timestep):
    """Convert the board observation into a hashable tuple-of-tuples state."""
    return tuple(map(tuple, np.copy(timestep.observation['board'])))

  def step(self, timestep):
    """Perform a single step in the environment."""
    # Get state observations.
    state = self._timestep_to_state(timestep)

    # This is one of the follow-up states (i.e. not the initial state), so
    # update the value function from the observed transition.
    if self._current_state is not None:
      self._update(timestep, state)

    self._current_state = state
    # Determine action.
    self._current_action = self._policy.get_action(self.epsilon, state)
    # Emit action.
    return self._current_action

  def _calculate_reward(self, timestep, unused_state):
    """Calculate reward: to be extended when impact penalty is added."""
    reward = timestep.reward
    return reward

  def _update(self, timestep, state):
    """Perform value function update.

    Implements the tabular Q-learning update:
      Q(s, a) <- Q(s, a) + alpha * (r + discount * max_a' Q(s', a') - Q(s, a))
    where the max term is dropped for terminal states.
    """
    reward = self._calculate_reward(timestep, state)

    # Terminal state (`end_episode` passes state=None).
    if state is None:
      delta = (reward - self._value_function[(self._current_state,
                                              self._current_action)])
    # Non-terminal state.
    else:
      # With epsilon=0 the policy acts greedily, giving max_a' Q(s', a').
      max_action = self._policy.get_action(0, state)
      delta = (
          reward + self.discount * self._value_function[(state, max_action)] -
          self._value_function[(self._current_state, self._current_action)])

    self._value_function[(self._current_state,
                          self._current_action)] += self.alpha * delta

  def end_episode(self, timestep):
    """Performs episode cleanup."""
    # Update for the terminal state.
    self._update(timestep, None)

  @property
  def value_function(self):
    return self._value_function
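

if __name__ == '__main__':
  # Minimal smoke-test sketch, not part of the original agent. It drives the
  # agent through one-step episodes of a toy bandit-like task. `_FakeSpec` and
  # `_FakeTimestep` are hypothetical stand-ins for dm_env-style action specs
  # and timesteps; any real environment exposing `.minimum`/`.maximum` and
  # `.observation['board']`/`.reward` would plug in the same way.
  _FakeSpec = collections.namedtuple('_FakeSpec', ['minimum', 'maximum'])
  _FakeTimestep = collections.namedtuple('_FakeTimestep',
                                         ['observation', 'reward'])

  agent = QLearning(actions=_FakeSpec(minimum=0, maximum=1))
  board = np.zeros((2, 2))  # A fixed dummy board, so there is a single state.

  for _ in range(500):
    agent.begin_episode()
    # First timestep: no reward yet; the agent only selects an action.
    action = agent.step(_FakeTimestep(observation={'board': board},
                                      reward=None))
    # Terminal timestep: reward 1.0 for action 1, else 0.0.
    agent.end_episode(_FakeTimestep(observation={'board': board},
                                    reward=float(action == 1)))

  # After training, the learned value for action 1 should dominate.
  print(dict(agent.value_function))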