Mirror of https://github.com/google-deepmind/deepmind-research.git, synced 2026-02-06 03:32:18 +08:00

commit 3f0d5ed1a0 (parent fba48d1e44), committed by Diego de Las Casas

    Internal change

    PiperOrigin-RevId: 417034217
avae/README.md (new file, 59 lines)
@@ -0,0 +1,59 @@
# The Autoencoding Variational Autoencoder

This is the code for the models in the NeurIPS 2020 paper
[AVAE](https://papers.nips.cc/paper/2020/file/ac10ff1941c540cd87c107330996f4f6-Paper.pdf).

The folder contains code to train the AVAE model in JAX; we will be uploading
the evaluation setup soon.

Code files in the folder:

- checkpointer.py: Checkpointing abstraction
- data_iterators.py: Dataset iterators (Mnist, ColorMnist)
- decoders.py: VAE decoder network architectures
- encoders.py: VAE encoder network architectures
- kl.py: KL computation between two Gaussians
- train.py: Function to train a model given an ELBO, networks and data
- train_main.py: Main file to train the AVAE
- vae.py: VAE model defining the various ELBOs

## Setup

To set up a Python 3 virtual environment with the required dependencies, run:

```shell
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt
```

## Running AVAE training

The following command will run AVAE training on the ColorMnist dataset using
the MLP network architectures:

```shell
python -m avae.train_main \
  --dataset='color_mnist' \
  --latent_dim=64 \
  --checkpoint_dir='/tmp/avae_checkpoints' \
  --checkpoint_filename='color_mnist_mlp_avae' \
  --rho=0.975 \
  --encoder='color_mnist_mlp_encoder' \
  --decoder='color_mnist_mlp_decoder'
```

## References

### Citing our work

If you use this code for your research, please consider citing our paper:

```bibtex
@article{cemgil2020autoencoding,
  title={The Autoencoding Variational Autoencoder},
  author={Cemgil, Taylan and Ghaisas, Sumedh and Dvijotham, Krishnamurthy and Gowal, Sven and Kohli, Pushmeet},
  journal={Advances in Neural Information Processing Systems},
  volume={33},
  year={2020}
}
```
avae/avae_image.png (new binary file, 164 KiB; binary file not shown)
avae/checkpointer.py (new file, 88 lines)
@@ -0,0 +1,88 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpointing functionality."""

import os
from typing import Any, Mapping, Optional

from absl import logging
import dill
import jax
import jax.numpy as jnp


class Checkpointer:
  """A checkpoint saving and loading class."""

  def __init__(self, checkpoint_dir: str, filename: str):
    """Class initializer.

    Args:
      checkpoint_dir: Checkpoint directory.
      filename: Filename of the checkpoint in the checkpoint directory.
    """
    self._checkpoint_dir = checkpoint_dir
    if not os.path.isdir(self._checkpoint_dir):
      os.mkdir(self._checkpoint_dir)

    self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)

  def save_checkpoint(
      self,
      experiment_state: Mapping[str, jnp.ndarray],
      opt_state: Mapping[str, jnp.ndarray],
      step: int,
      extra_checkpoint_info: Optional[Mapping[str, Any]] = None) -> None:
    """Save a checkpoint with the experiment state and step information.

    Args:
      experiment_state: Experiment params to be stored.
      opt_state: Optimizer state to be stored.
      step: Training iteration step.
      extra_checkpoint_info: Extra information to be stored.
    """
    if jax.host_id() != 0:
      return

    checkpoint_data = dict(
        experiment_state=jax.tree_map(jax.device_get, experiment_state),
        opt_state=jax.tree_map(jax.device_get, opt_state),
        step=step)
    if extra_checkpoint_info is not None:
      for key in extra_checkpoint_info:
        checkpoint_data[key] = extra_checkpoint_info[key]

    with open(self._checkpoint_path, 'wb') as checkpoint_file:
      dill.dump(checkpoint_data, checkpoint_file, protocol=2)

  def load_checkpoint(
      self) -> Optional[Mapping[str, Mapping[str, jnp.ndarray]]]:
    """Load and return checkpoint data.

    Returns:
      Loaded checkpoint if it exists, else None.
    """
    if os.path.exists(self._checkpoint_path):
      with open(self._checkpoint_path, 'rb') as checkpoint_file:
        checkpoint_data = dill.load(checkpoint_file)
      logging.info('Loading checkpoint from %s, saved at step %d',
                   self._checkpoint_path, checkpoint_data['step'])
      return checkpoint_data
    else:
      logging.warning('No pre-saved checkpoint found at %s',
                      self._checkpoint_path)
      return None
```
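As a minimal sketch of how `Checkpointer` might be used, assuming the `avae` package is importable (the state dicts below are illustrative placeholders, not part of the commit):

```python
import jax.numpy as jnp

from avae import checkpointer

ckpt = checkpointer.Checkpointer('/tmp/avae_demo', 'demo_ckpt')
ckpt.save_checkpoint(
    experiment_state={'w': jnp.zeros(3)},  # stands in for real model params
    opt_state={},                          # stands in for real optimizer state
    step=0,
    extra_checkpoint_info={'note': 'demo'})

restored = ckpt.load_checkpoint()
assert restored is not None and restored['step'] == 0
```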
avae/data_iterators.py (new file, 97 lines)
@@ -0,0 +1,97 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataset iterators: Mnist, ColorMnist."""

import enum

import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from avae import types


class Dataset(enum.Enum):
  color_mnist = enum.auto()


class MnistDataIterator(object):
  """Mnist data iterator class.

  Data is returned as the dataclass types.LabelledData.
  """

  def __init__(self, subset: str, batch_size: int):
    """Class initializer.

    Args:
      subset: Subset of the dataset.
      batch_size: Batch size of the returned dataset iterator.
    """
    datasets = tfds.load('mnist')
    train_dataset = datasets[subset]
    def _map_fun(x):
      return {'data': tf.cast(x['image'], tf.float32) / 255.0,
              'label': x['label']}
    train_dataset = train_dataset.map(_map_fun)
    train_dataset = train_dataset.batch(batch_size,
                                        drop_remainder=True).repeat()
    self._iterator = iter(tfds.as_numpy(train_dataset))
    self._batch_size = batch_size

  def __iter__(self):
    return self

  def __next__(self) -> types.LabelledData:
    return types.LabelledData(**next(self._iterator))


class ColorMnistDataIterator(MnistDataIterator):
  """Color Mnist data iterator.

  Each ColorMnist image has shape (28, 28, 3). A ColorMnist digit can take 7
  different colors, obtained by permutations of the RGB channels (turning the
  RGB channels on and off, excluding the all-channels-off permutation).

  Data is returned as the dataclass types.LabelledData.
  """

  def __next__(self) -> types.LabelledData:
    mnist_batch = next(self._iterator)
    mnist_image = mnist_batch['data']
    # Colors are produced by turning RGB channels on and off.
    # The possible colors are
    # [black (excluded), red, green, yellow, blue, magenta, cyan, white].
    # color_id takes values in [1, 8).
    color_id = np.random.randint(7, size=self._batch_size) + 1
    red_channel_bool = np.mod(color_id, 2)
    red_channel_bool = jnp.reshape(red_channel_bool, [-1, 1, 1, 1])
    blue_channel_bool = np.floor_divide(color_id, 4)
    blue_channel_bool = jnp.reshape(blue_channel_bool, [-1, 1, 1, 1])
    green_channel_bool = np.mod(np.floor_divide(color_id, 2), 2)
    green_channel_bool = jnp.reshape(green_channel_bool, [-1, 1, 1, 1])

    color_mnist_image = jnp.stack([
        jnp.multiply(red_channel_bool, mnist_image),
        jnp.multiply(blue_channel_bool, mnist_image),
        jnp.multiply(green_channel_bool, mnist_image)], axis=3)
    color_mnist_image = jnp.reshape(color_mnist_image, [-1, 28, 28, 3])
    # color_id takes values in [1, 8); to make classification code easier, a
    # `color` label attached to the data would take values in [0, 7)
    # (by subtracting 1 from color_id).
    return types.LabelledData(
        data=color_mnist_image, label=mnist_batch['label'])
```
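A minimal sketch of pulling one batch (this downloads the `mnist` TFDS dataset on first use; the shapes follow from the code above):

```python
from avae import data_iterators

it = data_iterators.ColorMnistDataIterator(subset='train', batch_size=8)
batch = next(it)
print(batch.data.shape)   # (8, 28, 28, 3), float32 values in [0, 1]
print(batch.label.shape)  # (8,), the underlying MNIST digit labels
```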
avae/decoders.py (new file, 88 lines)
@@ -0,0 +1,88 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decoder architectures to be used with the VAE."""

import abc

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


class DecoderBase(hk.Module):
  """Base class for decoder network classes."""

  def __init__(self, obs_var: float):
    """Class initializer.

    Args:
      obs_var: Observation variance of the dataset.
    """
    super().__init__()
    self._obs_var = obs_var

  @abc.abstractmethod
  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct from a given latent sample.

    Args:
      z: Latent samples of shape (batch_size, latent_dim).

    Returns:
      Reconstruction with shape (batch_size, ...).
    """

  def data_fidelity(
      self,
      input_data: jnp.ndarray,
      recons: jnp.ndarray,
  ) -> jnp.ndarray:
    """Compute the data fidelity (reconstruction loss) for given input and recons.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      recons: Reconstruction of the input data. An array with the same shape
        as `input_data`.

    Returns:
      The data fidelity term across the batch of data. An array of shape
      (batch_size,).
    """
    error = (input_data - recons).reshape(input_data.shape[0], -1)
    return -0.5 * jnp.sum(jnp.square(error), axis=1) / self._obs_var


class ColorMnistMLPDecoder(DecoderBase):
  """MLP decoder for ColorMnist."""

  _hidden_units = (200, 200, 200, 200)
  _image_dims = (28, 28, 3)  # Dimensions of a single ColorMnist image.

  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    """Reconstruct from a given latent sample.

    Args:
      z: Latent samples of shape (batch_size, latent_dim).

    Returns:
      Reconstructed data of shape (batch_size, 28, 28, 3).
    """
    out = z
    for units in self._hidden_units:
      out = hk.Linear(units)(out)
      out = jax.nn.relu(out)
    out = hk.Linear(np.product(self._image_dims))(out)
    out = jax.nn.sigmoid(out)
    return jnp.reshape(out, (-1,) + self._image_dims)
```
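The `data_fidelity` term above is the Gaussian log-likelihood of the data under N(recons, obs_var * I), with the constant normalization term dropped. A quick NumPy check of that reading (illustrative only, not part of the commit):

```python
import numpy as np

obs_var = 0.5
x = np.random.rand(4, 28, 28, 3).astype(np.float32)       # fake inputs
recons = np.random.rand(4, 28, 28, 3).astype(np.float32)  # fake reconstructions

err = (x - recons).reshape(x.shape[0], -1)
fidelity = -0.5 * np.sum(err**2, axis=1) / obs_var  # as in DecoderBase

# Adding back the dropped constant gives the full log-density
# log N(x; recons, obs_var * I), one value per batch element.
d = err.shape[1]
log_lik = fidelity - 0.5 * d * np.log(2 * np.pi * obs_var)
assert log_lik.shape == (4,)
```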
avae/encoders.py (new file, 112 lines)
@@ -0,0 +1,112 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Encoder architectures to be used with the VAE."""

import abc
from typing import Generic, TypeVar

import haiku as hk
import jax
import jax.numpy as jnp

from avae import types

_Params = TypeVar('_Params')


class EncoderBase(hk.Module, Generic[_Params]):
  """Abstract class for encoder architectures."""

  def __init__(self, latent_dim: int):
    """Class initializer.

    Args:
      latent_dim: Latent dimensionality of the model.
    """
    super().__init__()
    self._latent_dim = latent_dim

  @abc.abstractmethod
  def __call__(self, input_data: jnp.ndarray) -> _Params:
    """Return the posterior distribution over the latents.

    Args:
      input_data: Input batch of shape (batch_size, ...).

    Returns:
      Parameters of the posterior distribution over the latents.
    """

  @abc.abstractmethod
  def sample(self, posterior: _Params, key: jnp.ndarray) -> jnp.ndarray:
    """Sample from the given posterior distribution.

    Args:
      posterior: Parameters of the posterior distribution over the latents.
      key: Random number generator key.

    Returns:
      Sample from the posterior distribution over the latents, of shape
      (batch_size, latent_dim).
    """


class ColorMnistMLPEncoder(EncoderBase[types.NormalParams]):
  """MLP encoder for ColorMnist."""

  _hidden_units = (200, 200, 200, 200)

  def __call__(
      self, input_data: jnp.ndarray) -> types.NormalParams:
    """Return the posterior distribution over the latents.

    Args:
      input_data: Input batch of shape (batch_size, ...).

    Returns:
      Parameters of the posterior distribution over the latents.
    """
    out = hk.Flatten()(input_data)
    for units in self._hidden_units:
      out = hk.Linear(units)(out)
      out = jax.nn.relu(out)
    out = hk.Linear(2 * self._latent_dim)(out)
    return _normal_params_from_logits(out)

  def sample(
      self,
      posterior: types.NormalParams,
      key: jnp.ndarray,
  ) -> jnp.ndarray:
    """Sample from the given normal posterior (mean, variance) distribution.

    Args:
      posterior: Posterior over the latents.
      key: Random number generator key.

    Returns:
      Sample from the posterior distribution over the latents, of shape
      (batch_size, latent_dim).
    """
    eps = jax.random.normal(
        key, shape=(posterior.mean.shape[0], self._latent_dim))
    return posterior.mean + eps * posterior.variance


def _normal_params_from_logits(
    logits: jnp.ndarray) -> types.NormalParams:
  """Construct the mean and variance of a normal distribution from logits."""
  mean, log_variance = jnp.split(logits, 2, axis=1)
  variance = jnp.exp(log_variance)
  return types.NormalParams(mean=mean, variance=variance)
```
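For reference, the textbook reparameterization trick draws z = mean + eps * std, where std is the square root of the variance; the commit's `sample` multiplies `eps` by the variance field directly. A sketch of the standard form, stated as a point of comparison rather than a correction of the code above:

```python
import jax
import jax.numpy as jnp


def reparameterize(mean, variance, key):
  """Standard reparameterized sample from N(mean, diag(variance))."""
  eps = jax.random.normal(key, shape=mean.shape)
  return mean + eps * jnp.sqrt(variance)  # sqrt turns variance into std


z = reparameterize(jnp.zeros((2, 4)), jnp.ones((2, 4)), jax.random.PRNGKey(0))
```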
avae/kl.py (new file, 43 lines)
@@ -0,0 +1,43 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Various KL implementations in JAX."""

import jax.numpy as jnp


def kl_p_with_uniform_normal(mean: jnp.ndarray,
                             variance: jnp.ndarray) -> jnp.ndarray:
  r"""KL divergence between a Gaussian P and the standard normal prior.

  Args:
    mean: Mean of the Gaussian distribution, shape (latent_dims,).
    variance: Variance of the Gaussian distribution, shape (latent_dims,),
      or a full covariance matrix of shape (latent_dims, latent_dims).

  Returns:
    KL divergence KL(P || N(0, I)), shape ().
  """
  if len(variance.shape) == 2:
    # `variance` is a full covariance matrix.
    variance_trace = jnp.trace(variance)
    _, ldet1 = jnp.linalg.slogdet(variance)
  else:
    variance_trace = jnp.sum(variance)
    ldet1 = jnp.sum(jnp.log(variance))

  mean_contribution = jnp.sum(jnp.square(mean))
  res = -ldet1
  res += variance_trace + mean_contribution - mean.shape[0]
  return res * 0.5
```
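This implements the closed form KL(N(mu, Sigma) || N(0, I)) = 0.5 * (tr(Sigma) + ||mu||^2 - d - log det Sigma). A quick numeric sanity check of the diagonal branch (illustrative):

```python
import numpy as np

mean = np.array([0.5, -1.0])
variance = np.array([0.8, 1.2])  # diagonal covariance entries

# 0.5 * (tr(S) + |m|^2 - d - log det S), matching the diagonal branch above.
kl = 0.5 * (variance.sum() + (mean**2).sum()
            - mean.shape[0] - np.log(variance).sum())
print(kl)  # ~0.645
```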
avae/requirements.txt (new file, 8 lines)
@@ -0,0 +1,8 @@
```
absl-py>0.9
dill==0.3.2
dm-haiku
jax
optax
tensorflow
tensorflow_datasets
```
avae/run.sh (new executable file, 31 lines)
@@ -0,0 +1,31 @@
```shell
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt

# Start training the AVAE model.
python -m avae.train_main \
  --dataset='color_mnist' \
  --latent_dim=64 \
  --checkpoint_dir='/tmp/avae_checkpoints' \
  --checkpoint_filename='color_mnist_mlp_avae' \
  --rho=0.975 \
  --encoder=color_mnist_mlp_encoder \
  --decoder=color_mnist_mlp_decoder \
  --iterations=1
```
avae/train.py (new file, 117 lines)
@@ -0,0 +1,117 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VAE-style training."""

from typing import Any, Callable, Iterator, Sequence, Mapping, Tuple, Optional

import haiku as hk
import jax
import jax.numpy as jnp
import optax

from avae import checkpointer
from avae import types


def train(
    train_data_iterator: Iterator[types.LabelledData],
    test_data_iterator: Iterator[types.LabelledData],
    elbo_fun: hk.Transformed,
    learning_rate: float,
    checkpoint_dir: str,
    checkpoint_filename: str,
    checkpoint_every: int,
    test_every: int,
    iterations: int,
    rng_seed: int,
    test_functions: Optional[Sequence[Callable[[Mapping[str, jnp.ndarray]],
                                               Tuple[str, float]]]] = None,
    extra_checkpoint_info: Optional[Mapping[str, Any]] = None):
  """Train a VAE with the given data iterator and ELBO definition.

  Args:
    train_data_iterator: Iterator of batched training data.
    test_data_iterator: Iterator of batched testing data.
    elbo_fun: Haiku-transformed function returning the ELBO.
    learning_rate: Learning rate to be used with the optimizer.
    checkpoint_dir: Path of the checkpoint directory.
    checkpoint_filename: Filename of the checkpoint.
    checkpoint_every: Checkpoint every N iterations.
    test_every: Test and log results every N iterations.
    iterations: Number of training iterations to perform.
    rng_seed: Seed for the random number generator.
    test_functions: Iterable of test functions; each takes the params and
      outputs extra info to print at test and log time.
    extra_checkpoint_info: Extra info to put inside the saved checkpoint.
  """
  rng_seq = hk.PRNGSequence(jax.random.PRNGKey(rng_seed))

  opt_init, opt_update = optax.chain(
      # Set the parameters of Adam. Note the learning_rate is not here.
      optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
      # Put a minus sign to *minimise* the loss.
      optax.scale(-learning_rate))

  @jax.jit
  def loss(params, key, data):
    elbo_outputs = elbo_fun.apply(params, key, data)
    return -jnp.mean(elbo_outputs.elbo)

  @jax.jit
  def loss_test(params, key, data):
    elbo_output = elbo_fun.apply(params, key, data)
    return (-jnp.mean(elbo_output.elbo), jnp.mean(elbo_output.data_fidelity),
            jnp.mean(elbo_output.kl))

  @jax.jit
  def update_step(params, key, data, opt_state):
    grads = jax.grad(loss, has_aux=False)(params, key, data)
    updates, opt_state = opt_update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

  exp_checkpointer = checkpointer.Checkpointer(
      checkpoint_dir, checkpoint_filename)
  experiment_data = exp_checkpointer.load_checkpoint()

  if experiment_data is not None:
    start = experiment_data['step']
    params = experiment_data['experiment_state']
    opt_state = experiment_data['opt_state']
  else:
    start = 0
    params = elbo_fun.init(
        next(rng_seq), next(train_data_iterator).data)
    opt_state = opt_init(params)

  for step in range(start, iterations, 1):
    if step % test_every == 0:
      test_loss, ll, kl = loss_test(params, next(rng_seq),
                                    next(test_data_iterator).data)
      output_message = (f'Step {step} elbo {-test_loss:0.2f} LL {ll:0.2f} '
                        f'KL {kl:0.2f}')
      if test_functions:
        for test_function in test_functions:
          name, test_output = test_function(params)
          output_message += f' {name}: {test_output:0.2f}'
      print(output_message)

    params, opt_state = update_step(params, next(rng_seq),
                                    next(train_data_iterator).data, opt_state)

    if step % checkpoint_every == 0:
      exp_checkpointer.save_checkpoint(
          params, opt_state, step, extra_checkpoint_info)
```
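A minimal sketch of wiring a Haiku-transformed ELBO into `train` (the toy ELBO and constant-data iterators here are placeholders invented for illustration, not part of the commit):

```python
import itertools

import haiku as hk
import jax.numpy as jnp
import numpy as np

from avae import train
from avae import types


def _toy_elbo(data):
  # Stand-in "model": one linear layer; fidelity/kl are toy quantities.
  recons = hk.Linear(data.shape[-1])(data)
  fidelity = -jnp.sum(jnp.square(data - recons), axis=-1)
  kl = jnp.zeros_like(fidelity)
  return types.ELBOOutputs(fidelity - kl, fidelity, kl)


elbo_fun = hk.transform(_toy_elbo)
batches = itertools.repeat(
    types.LabelledData(data=np.zeros((8, 4), np.float32), label=None))

train.train(
    train_data_iterator=batches, test_data_iterator=batches,
    elbo_fun=elbo_fun, learning_rate=1e-4,
    checkpoint_dir='/tmp/avae_toy', checkpoint_filename='toy',
    checkpoint_every=10, test_every=10, iterations=20, rng_seed=0)
```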
avae/train_main.py (new file, 127 lines)
@@ -0,0 +1,127 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main file for training VAE/AVAE."""

import enum
from absl import app
from absl import flags
import haiku as hk

from avae import data_iterators
from avae import decoders
from avae import encoders
from avae import train
from avae import vae


class Model(enum.Enum):
  vae = enum.auto()
  avae = enum.auto()


class EncoderArch(enum.Enum):
  color_mnist_mlp_encoder = 'ColorMnistMLPEncoder'


class DecoderArch(enum.Enum):
  color_mnist_mlp_decoder = 'ColorMnistMLPDecoder'


_DATASET = flags.DEFINE_enum_class(
    'dataset', data_iterators.Dataset.color_mnist, data_iterators.Dataset,
    'Dataset to train on')
_LATENT_DIM = flags.DEFINE_integer('latent_dim', 32,
                                   'Number of latent dimensions.')
_TRAIN_BATCH_SIZE = flags.DEFINE_integer('train_batch_size', 64,
                                         'Train batch size.')
_TEST_BATCH_SIZE = flags.DEFINE_integer('test_batch_size', 64,
                                        'Testing batch size.')
_TEST_EVERY = flags.DEFINE_integer('test_every', 1000,
                                   'Test every N iterations.')
_ITERATIONS = flags.DEFINE_integer('iterations', 102000,
                                   'Number of training iterations.')
_OBS_VAR = flags.DEFINE_float('obs_var', 0.5,
                              'Observation variance of the data. (Default 0.5)')

_MODEL = flags.DEFINE_enum_class('model', Model.avae, Model,
                                 'Model used for training.')
_RHO = flags.DEFINE_float('rho', 0.8, 'Rho parameter used with AVAE or SE.')
_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
_RNG_SEED = flags.DEFINE_integer('rng_seed', 0,
                                 'Seed for random number generator.')
_CHECKPOINT_DIR = flags.DEFINE_string('checkpoint_dir', '/tmp/',
                                      'Directory for checkpointing.')
_CHECKPOINT_FILENAME = flags.DEFINE_string(
    'checkpoint_filename', 'color_mnist_avae_mlp', 'Checkpoint filename.')
_CHECKPOINT_EVERY = flags.DEFINE_integer(
    'checkpoint_every', 1000, 'Checkpoint every N steps.')
_ENCODER = flags.DEFINE_enum_class(
    'encoder', EncoderArch.color_mnist_mlp_encoder, EncoderArch,
    'Encoder class name.')
_DECODER = flags.DEFINE_enum_class(
    'decoder', DecoderArch.color_mnist_mlp_decoder, DecoderArch,
    'Decoder class name.')


def main(_):
  if _DATASET.value is data_iterators.Dataset.color_mnist:
    train_data_iterator = iter(
        data_iterators.ColorMnistDataIterator('train', _TRAIN_BATCH_SIZE.value))
    test_data_iterator = iter(
        data_iterators.ColorMnistDataIterator('test', _TEST_BATCH_SIZE.value))

  def _elbo_fun(input_data):
    if _ENCODER.value is EncoderArch.color_mnist_mlp_encoder:
      encoder = encoders.ColorMnistMLPEncoder(_LATENT_DIM.value)

    if _DECODER.value is DecoderArch.color_mnist_mlp_decoder:
      decoder = decoders.ColorMnistMLPDecoder(_OBS_VAR.value)

    vae_obj = vae.VAE(encoder, decoder, _RHO.value)

    if _MODEL.value is Model.vae:
      return vae_obj.vae_elbo(input_data, hk.next_rng_key())
    else:
      return vae_obj.avae_elbo(input_data, hk.next_rng_key())

  elbo_fun = hk.transform(_elbo_fun)

  extra_checkpoint_info = {
      'dataset': _DATASET.value.name,
      'encoder': _ENCODER.value.name,
      'decoder': _DECODER.value.name,
      'obs_var': _OBS_VAR.value,
      'rho': _RHO.value,
      'latent_dim': _LATENT_DIM.value,
  }

  train.train(
      train_data_iterator=train_data_iterator,
      test_data_iterator=test_data_iterator,
      elbo_fun=elbo_fun,
      learning_rate=_LEARNING_RATE.value,
      checkpoint_dir=_CHECKPOINT_DIR.value,
      checkpoint_filename=_CHECKPOINT_FILENAME.value,
      checkpoint_every=_CHECKPOINT_EVERY.value,
      test_every=_TEST_EVERY.value,
      iterations=_ITERATIONS.value,
      rng_seed=_RNG_SEED.value,
      extra_checkpoint_info=extra_checkpoint_info)


if __name__ == '__main__':
  app.run(main)
```
avae/types.py (new file, 53 lines)
@@ -0,0 +1,53 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Useful dataclass types used across the code."""

from typing import Optional

import dataclasses
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass(frozen=True)
class ELBOOutputs:
  elbo: jnp.ndarray
  data_fidelity: jnp.ndarray
  kl: jnp.ndarray


@dataclasses.dataclass(frozen=True)
class LabelledData:
  """A batch of labelled examples.

  Attributes:
    data: Array of shape (batch_size, ...).
    label: Array of shape (batch_size, ...).
  """
  data: np.ndarray
  label: Optional[np.ndarray]


@dataclasses.dataclass(frozen=True)
class NormalParams:
  """Parameters of a normal distribution.

  Attributes:
    mean: Array of shape (batch_size, latent_dim).
    variance: Array of shape (batch_size, latent_dim).
  """
  mean: jnp.ndarray
  variance: jnp.ndarray
```
avae/vae.py (new file, 133 lines)
@@ -0,0 +1,133 @@
```python
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Standard VAE class."""

from typing import Optional

import jax
import jax.numpy as jnp

from avae import decoders
from avae import encoders
from avae import kl
from avae import types


class VAE:
  """VAE class.

  This class defines the ELBOs used in training VAE models. It also adds a
  function for forward passes of data through the VAE.
  """

  def __init__(self, encoder: encoders.EncoderBase,
               decoder: decoders.DecoderBase, rho: Optional[float] = None):
    """Class initializer.

    Args:
      encoder: Encoder network architecture.
      decoder: Decoder network architecture.
      rho: Rho parameter used in AVAE training.
    """
    self._encoder = encoder
    self._decoder = decoder
    self._rho = rho

  def vae_elbo(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> types.ELBOOutputs:
    """ELBO for training a VAE.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for the random number generator.

    Returns:
      The computed VAE ELBO as types.ELBOOutputs.
    """
    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    kls = jax.vmap(kl.kl_p_with_uniform_normal, [0, 0])(
        posterior.mean, posterior.variance)
    recons = self._decoder(samples)
    data_fidelity = self._decoder.data_fidelity(input_data, recons)
    elbo = data_fidelity - kls
    return types.ELBOOutputs(elbo, data_fidelity, kls)

  def avae_elbo(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> types.ELBOOutputs:
    """ELBO for training the AVAE model.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for the random number generator.

    Returns:
      The computed AVAE ELBO as types.ELBOOutputs. All arrays keep the batch
      dimension intact.
    """
    aux_images = jax.lax.stop_gradient(self(input_data, key))

    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    kls = jax.vmap(kl.kl_p_with_uniform_normal, [0, 0])(
        posterior.mean, posterior.variance)
    recons = self._decoder(samples)
    data_fidelity = self._decoder.data_fidelity(input_data, recons)
    elbo = data_fidelity - kls

    aux_posterior = self._encoder(aux_images)
    latent_mean = posterior.mean
    latent_var = posterior.variance
    aux_latent_mean = aux_posterior.mean
    aux_latent_var = aux_posterior.variance
    latent_dim = latent_mean.shape[1]

    def _reduce(x):
      return jnp.mean(jnp.sum(x, axis=1))

    # Computation of <log p(Z_aux | Z)>.
    expected_log_conditional = (
        aux_latent_var + jnp.square(self._rho) * latent_var +
        jnp.square(aux_latent_mean - self._rho * latent_mean))
    expected_log_conditional = _reduce(expected_log_conditional)
    expected_log_conditional /= 2.0 * (1.0 - jnp.square(self._rho))
    expected_log_conditional = (latent_dim *
                                jnp.log(1.0 / (2 * jnp.pi)) -
                                expected_log_conditional)
    elbo += expected_log_conditional
    # Entropy of Z_aux.
    elbo += _reduce(0.5 * jnp.log(2 * jnp.pi * jnp.e * aux_latent_var))

    return types.ELBOOutputs(elbo, data_fidelity, kls)

  def __call__(
      self, input_data: jnp.ndarray,
      key: jnp.ndarray) -> jnp.ndarray:
    """Reconstruction of the input data.

    Args:
      input_data: Input batch of shape (batch_size, ...).
      key: Key for the random number generator.

    Returns:
      Reconstruction of the input data as a jnp.ndarray of shape
      (batch_dim, observation_dims).
    """
    posterior = self._encoder(input_data)
    samples = self._encoder.sample(posterior, key)
    recons = self._decoder(samples)
    return recons
```
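For reference, with the auxiliary conditional p(z_aux | z) = N(rho * z, (1 - rho^2) * I) and diagonal Gaussian moments (mu, sigma^2) for z and (mu_a, sigma_a^2) for z_aux, the cross term above follows the identity below (stated as a sketch of the expectation; the code folds the additive constant in slightly differently):

```latex
\mathbb{E}_q\left[\log p(z_{\mathrm{aux}} \mid z)\right]
  = -\frac{d}{2}\log\left(2\pi(1-\rho^2)\right)
    - \sum_{i=1}^{d}
      \frac{\sigma_{a,i}^2 + \rho^2\sigma_i^2
            + (\mu_{a,i} - \rho\,\mu_i)^2}{2(1-\rho^2)}
```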
fusion_tcv/README.md (new file, 83 lines)
@@ -0,0 +1,83 @@
# TCV Fusion Control Objectives

This code release contains the rewards, control targets, noise model and
parameter variation for the paper *Magnetic control of tokamak plasmas through
deep reinforcement learning*, published in Nature in ... 2021.

## Disclaimer

This release is useful for understanding the details of specific elements of
the learning architecture; however, it does not contain the simulator
([FGE](https://infoscience.epfl.ch/record/283775), part of
[LIUQE](https://www.epfl.ch/research/domains/swiss-plasma-center/liuqe-suite/)),
the trained/exported control policies, nor the agent training infrastructure.
This release is useful for replicating our results, which can be done by
assembling all of the components, many of which are open source elsewhere.
Please see the "Code Availability" statement in the paper for more
information.

The learning algorithm we used is
[MPO](https://arxiv.org/abs/1812.02256), which has an open source
[reference implementation](https://github.com/deepmind/acme) in
[Acme](https://arxiv.org/abs/2006.00979). Additionally, the open source
software libraries [launchpad](https://arxiv.org/abs/2106.04516)
([code](https://github.com/deepmind/launchpad)),
[dm_env](https://github.com/deepmind/dm_env),
[sonnet](https://github.com/deepmind/sonnet),
[tensorflow](https://arxiv.org/abs/1603.04467)
([code](https://www.tensorflow.org/)) and
[reverb](https://arxiv.org/pdf/2102.04736.pdf)
([code](https://github.com/deepmind/reverb)) were used.
FGE and LIUQE are available on request from the
[Swiss Plasma Center](https://www.epfl.ch/research/domains/swiss-plasma-center/)
at [EPFL](https://www.epfl.ch/en/)
(email [Federico Felici](mailto:federico.felici@epfl.ch)), subject to
agreement.

## Objectives used in published TCV experiments

Take a look at `rewards_used.py` and `references.py` for the rewards and
control targets used in the paper.

To print the actual control targets, run:

```sh
$ python3 -m fusion_tcv.references_main --refs=snowflake
```

Make sure to install the dependencies first:
`pip install -r fusion_tcv/requirements.txt`.

## Overview of files

* `agent.py`: An interface for what a real agent might look like. This mainly
  exists so the run loop builds.
* `combiners.py`: Defines combiners that combine multiple values into one.
  Useful for creating a scalar reward from a set of reward components.
* `environment.py`: Augments the FGE simulator to make it an RL environment
  (parameter variation, reward computation, etc.).
* `experiments.py`: The environment definitions published in the paper.
* `fge_octave.py`: An interface for the simulator. Given that FGE isn't open
  source, this is just a sketch of how you might use it.
* `fge_state.py`: A python interface to the simulator state. Given that FGE
  isn't open source, this is a sketch and returns fake data.
* `named_array.py`: Makes it easier to interact with numpy arrays by name.
  Useful for referring to parts of the control targets/references.
* `noise.py`: Adds action and observation noise.
* `param_variation.py`: Defines the physics parameters used by the simulation.
* `ref_gen.py`: The tools to generate control targets per time step.
* `references.py`: The control targets used in the experiments.
* `references_main.py`: Runnable script to output the references to the
  command line.
* `rewards.py`: Computes all of the reward components and combines them
  together to create a single scalar reward for the environment.
* `rewards_used.py`: The actual reward definitions used in the experiments.
* `run_loop.py`: An example of interacting with the environment to generate a
  trajectory to send to the replay buffer. This code is notional, as a
  complete version would require an Agent and a Simulator implementation.
* `targets.py`: The reward components that pull data from the control targets
  and the simulator state for generating rewards. This depends on FGE, so
  cannot be run.
* `tcv_common.py`: Defines the physical attributes of the TCV fusion reactor,
  and how to interact with it.
* `terminations.py`: Defines when the simulation should stop.
* `trajectory.py`: Stores the history of the episode.
* `transforms.py`: Turns error values into normalized values.
fusion_tcv/agent.py (new file, 41 lines)
@@ -0,0 +1,41 @@
```python
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An agent interface for interacting with the environment."""

import abc

import dm_env
import numpy as np

from fusion_tcv import tcv_common


class AbstractAgent(abc.ABC):
  """Agent base class."""

  def reset(self):
    """Reset to the initial state."""

  @abc.abstractmethod
  def step(self, timestep: dm_env.TimeStep) -> np.ndarray:
    """Return the action given the current observations."""


class ZeroAgent(AbstractAgent):
  """An agent that always returns "zero" actions."""

  def step(self, timestep: dm_env.TimeStep) -> np.ndarray:
    del timestep
    return np.zeros(tcv_common.action_spec().shape,
                    tcv_common.action_spec().dtype)
```
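A minimal sketch of driving `ZeroAgent` with a hand-built timestep (the observation below is a placeholder; a real one would come from the environment):

```python
import dm_env
import numpy as np

from fusion_tcv import agent

zero = agent.ZeroAgent()
ts = dm_env.TimeStep(step_type=dm_env.StepType.FIRST, reward=None,
                     discount=None, observation=np.zeros(3))
action = zero.step(ts)  # zeros matching tcv_common.action_spec()
```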
fusion_tcv/combiners.py (new file, 220 lines)
@@ -0,0 +1,220 @@
```python
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward combiners."""

import abc
import math
from typing import List, Optional, Tuple

import dataclasses
import numpy as np
from scipy import special

from fusion_tcv import targets


class AbstractCombiner(targets.AbstractTarget):
  """Combines a set of rewards, possibly weighted."""

  @abc.abstractmethod
  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    """Combines a set of rewards, possibly weighted."""

  @property
  def outputs(self) -> int:
    """All combiners return exactly one value, even if it's NaN."""
    return 1

  @staticmethod
  def _clean_values_weights(
      values: List[float],
      weights: Optional[List[float]] = None) -> Tuple[List[float], List[float]]:
    """Validate the values and weights; if no weights, return equal weights."""
    if weights is None:
      weights = [1] * len(values)
    else:
      if len(values) != len(weights):
        raise ValueError("Number of weights doesn't match number of values. "
                         f"values: {len(values)}, weights: {len(weights)}")
      for w in weights:
        if w < 0:
          raise ValueError(f"Weights must be >=0: {w}")

    new_values_weights = [(v, w) for v, w in zip(values, weights)
                          if not np.isnan(v) and w > 0]
    return tuple(zip(*new_values_weights)) if new_values_weights else ([], [])


class Mean(AbstractCombiner):
  """Take the weighted mean of the values.

  Ignores NaNs and values with weight 0.
  """

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, weights = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    return [sum(r * w for r, w in zip(values, weights)) / sum(weights)]


def _multiply(values, weights, mean):
  """Multiplies the values, taking care to validate the weights.

  Defines 0^0 = 1 so a reward with no weight is "off" even if the value is 0.

  Args:
    values: The reward values.
    weights: The reward weights.
    mean: If true, divides by the sum of the weights (computes the geometric
      mean).

  Returns:
    Product of v^w across the components.
  """
  # If weight and value are both zero, set the value to 1 so that 0^0 = 1.
  values = [1 if (v == 0 and w == 0) else v for (v, w) in zip(values, weights)]
  if any(v == 0 for v in values):
    return [0]
  den = sum(weights) if mean else 1
  return [math.exp(sum(np.log(values) * weights) / den)]


class Multiply(AbstractCombiner):
  """Combine by multiplying the (weighted) values together.

  This is the same as the geometric mean, but without the n-th root taken at
  the end. This means doing poorly on several rewards compounds, rather than
  averages. As such, it likely only makes sense after the non-linearities,
  i.e. where the values are in the 0-1 range; otherwise it'll cause them to
  increase. This is even harsher than Min or SmoothMax(-inf).

  Ignores NaNs and values with weight 0.
  """

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, weights = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    return _multiply(values, weights, mean=False)


class GeometricMean(AbstractCombiner):
  """Take the weighted geometric mean of the values.

  Pushes values towards 0, so likely only makes sense after the non-linear
  transforms.

  Ignores NaNs and values with weight 0.
  """

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, weights = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    return _multiply(values, weights, mean=True)


class Min(AbstractCombiner):
  """Take the min of the values. Ignores NaNs and values with weight 0."""

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, _ = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    return [min(values)]


class Max(AbstractCombiner):
  """Take the max of the values. Ignores NaNs and values with weight 0."""

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, _ = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    return [max(values)]


@dataclasses.dataclass(frozen=True)
class LNorm(AbstractCombiner):
  """Take the l-norm of the values.

  Reasonable norm values (assuming normalized):
  - 1: average of the values
  - 2: Euclidean distance metric
  - inf: max value

  Values in between go between the average and the max. As the l-norm goes up,
  the result gets closer to the max.

  Normalized means dividing by the max possible distance, such that the units
  still make sense.

  This likely only makes sense before the non-linear transforms. SmoothMax is
  similar but more flexible and understandable.

  Ignores NaNs and values with weight 0.
  """
  norm: float
  normalized: bool = True

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, _ = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    lnorm = np.linalg.norm(values, ord=self.norm)
    if self.normalized:
      lnorm /= np.linalg.norm(np.ones(len(values)), ord=self.norm)
    return [float(lnorm)]


@dataclasses.dataclass(frozen=True)
class SmoothMax(AbstractCombiner):
  """Combines component rewards using a smooth maximum.

  https://en.wikipedia.org/wiki/Smooth_maximum
  alpha is the exponent for the smooth max:
  - alpha -> inf: returns the maximum
  - alpha == 0: returns the weighted average
  - alpha -> -inf: returns the minimum
  Values of alpha in between return values in between.

  Since this varies between min, mean and max, it keeps the existing scale.

  Alpha >= 0 makes sense before converting to 0-1; alpha <= 0 makes sense
  after.

  Ignores NaNs and values with weight 0.
  """
  alpha: float

  def __call__(self, values: List[float],
               weights: Optional[List[float]] = None) -> List[float]:
    values, weights = self._clean_values_weights(values, weights)
    if not values:
      return [float("nan")]
    if math.isinf(self.alpha):
      return [max(values) if self.alpha > 0 else min(values)]
    # Compute the weights in a numerically-friendly way.
    log_soft_weights = np.array([np.log(w) + c * self.alpha
                                 for w, c in zip(weights, values)])
    log_soft_weights -= special.logsumexp(log_soft_weights)
    soft_weights = np.exp(log_soft_weights)
    return Mean()(values, soft_weights)  # weighted mean
```
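SmoothMax is a softmax-weighted average: each value's weight is scaled by e^(alpha * value) before an ordinary weighted mean is taken. As a math sketch of what `__call__` computes (consistent with the test below asserting SmoothMax(1)([0, 1]) = 0.7310585786300049):

```latex
\mathrm{SmoothMax}_\alpha(v; w)
  = \frac{\sum_i v_i\, w_i\, e^{\alpha v_i}}{\sum_i w_i\, e^{\alpha v_i}},
\qquad
\mathrm{SmoothMax}_1([0, 1]) = \frac{0 \cdot 1 + 1 \cdot e}{1 + e}
\approx 0.731
```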
fusion_tcv/combiners_test.py (new file, 157 lines)
@@ -0,0 +1,157 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS-IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for combiners."""
|
||||
|
||||
import math
|
||||
|
||||
from absl.testing import absltest
|
||||
from fusion_tcv import combiners
|
||||
|
||||
|
||||
NAN = float("nan")
|
||||
|
||||
|
||||
class CombinersTest(absltest.TestCase):
|
||||
|
||||
def assertNan(self, value):
|
||||
self.assertLen(value, 1)
|
||||
self.assertTrue(math.isnan(value[0]))
|
||||
|
||||
def test_errors(self):
|
||||
c = combiners.Mean()
|
||||
with self.assertRaises(ValueError):
|
||||
c([0, 1], [1])
|
||||
with self.assertRaises(ValueError):
|
||||
c([0, 1], [1, 2, 3])
|
||||
with self.assertRaises(ValueError):
|
||||
c([0, 1], [-1, 2])
|
||||
|
||||
def test_mean(self):
|
||||
c = combiners.Mean()
|
||||
self.assertEqual(c([0, 2, 4]), [2])
|
||||
self.assertEqual(c([0, 0.5, 1]), [0.5])
|
||||
self.assertEqual(c([0, 0.5, 1], [0, 0, 1]), [1])
|
||||
self.assertEqual(c([0, 1], [1, 3]), [0.75])
|
||||
self.assertEqual(c([0, NAN], [1, 3]), [0])
|
||||
self.assertNan(c([NAN, NAN], [1, 3]))
|
||||
|
||||
def test_geometric_mean(self):
|
||||
c = combiners.GeometricMean()
|
||||
self.assertEqual(c([0.5, 0]), [0])
|
||||
self.assertEqual(c([0.3]), [0.3])
|
||||
self.assertEqual(c([4, 4]), [4])
|
||||
self.assertEqual(c([0.5, 0.5]), [0.5])
|
||||
self.assertEqual(c([0.5, 0.5], [1, 3]), [0.5])
|
||||
self.assertEqual(c([0.5, 1], [1, 2]), [0.5**(1/3)])
|
||||
self.assertEqual(c([0.5, 1], [2, 1]), [0.5**(2/3)])
|
||||
    self.assertEqual(c([0.5, 0], [2, 0]), [0.5])
    self.assertEqual(c([0.5, 0, 0], [2, 1, 0]), [0])
    self.assertEqual(c([0.5, NAN, 0], [2, 1, 0]), [0.5])
    self.assertNan(c([NAN, NAN], [1, 3]))

  def test_multiply(self):
    c = combiners.Multiply()
    self.assertEqual(c([0.5, 0]), [0])
    self.assertEqual(c([0.3]), [0.3])
    self.assertEqual(c([0.5, 0.5]), [0.25])
    self.assertEqual(c([0.5, 0.5], [1, 3]), [0.0625])
    self.assertEqual(c([0.5, 1], [1, 2]), [0.5])
    self.assertEqual(c([0.5, 1], [2, 1]), [0.25])
    self.assertEqual(c([0.5, 0], [2, 0]), [0.25])
    self.assertEqual(c([0.5, 0, 0], [2, 1, 0]), [0])
    self.assertEqual(c([0.5, NAN], [1, 1]), [0.5])
    self.assertNan(c([NAN, NAN], [1, 3]))

  def test_min(self):
    c = combiners.Min()
    self.assertEqual(c([0, 1]), [0])
    self.assertEqual(c([0.5, 1]), [0.5])
    self.assertEqual(c([1, 0.75]), [0.75])
    self.assertEqual(c([1, 3]), [1])
    self.assertEqual(c([1, 1, 3], [0, 1, 1]), [1])
    self.assertEqual(c([NAN, 3]), [3])
    self.assertNan(c([NAN, NAN], [1, 3]))

  def test_max(self):
    c = combiners.Max()
    self.assertEqual(c([0, 1]), [1])
    self.assertEqual(c([0.5, 1]), [1])
    self.assertEqual(c([1, 0.75]), [1])
    self.assertEqual(c([1, 3]), [3])
    self.assertEqual(c([1, 1, 3], [0, 1, 1]), [3])
    self.assertEqual(c([NAN, 3]), [3])
    self.assertNan(c([NAN, NAN], [1, 3]))

  def test_lnorm(self):
    c = combiners.LNorm(1)
    self.assertEqual(c([0, 2, 4]), [2])
    self.assertEqual(c([0, 0.5, 1]), [0.5])
    self.assertEqual(c([3, 4]), [7 / 2])
    self.assertEqual(c([0, 2, 4], [1, 1, 0]), [1])
    self.assertEqual(c([0, 2, NAN]), [1])
    self.assertNan(c([NAN, NAN], [1, 3]))

    c = combiners.LNorm(1, normalized=False)
    self.assertEqual(c([0, 2, 4]), [6])
    self.assertEqual(c([0, 0.5, 1]), [1.5])
    self.assertEqual(c([3, 4]), [7])

    c = combiners.LNorm(2)
    self.assertEqual(c([3, 4]), [5 / 2**0.5])

    c = combiners.LNorm(2, normalized=False)
    self.assertEqual(c([3, 4]), [5])

    c = combiners.LNorm(math.inf)
    self.assertAlmostEqual(c([3, 4])[0], 4)

    c = combiners.LNorm(math.inf, normalized=False)
    self.assertAlmostEqual(c([3, 4])[0], 4)

  def test_smoothmax(self):
    # Max
    c = combiners.SmoothMax(math.inf)
    self.assertEqual(c([0, 1]), [1])
    self.assertEqual(c([0.5, 1]), [1])
    self.assertEqual(c([1, 0.75]), [1])
    self.assertEqual(c([1, 3]), [3])

    # Smooth Max
    c = combiners.SmoothMax(1)
    self.assertAlmostEqual(c([0, 1])[0], 0.7310585786300049)

    # Mean
    c = combiners.SmoothMax(0)
    self.assertEqual(c([0, 2, 4]), [2])
    self.assertEqual(c([0, 0.5, 1]), [0.5])
    self.assertEqual(c([0, 0.5, 1], [0, 0, 1]), [1])
    self.assertEqual(c([0, 2, NAN]), [1])
    self.assertEqual(c([0, 2, NAN], [0, 1, 1]), [2])
    self.assertAlmostEqual(c([0, 1], [1, 3])[0], 0.75)
    self.assertNan(c([NAN, NAN], [1, 3]))

    # Smooth Min
    c = combiners.SmoothMax(-1)
    self.assertEqual(c([0, 1])[0], 0.2689414213699951)

    # Min
    c = combiners.SmoothMax(-math.inf)
    self.assertEqual(c([0, 1]), [0])
    self.assertEqual(c([0.5, 1]), [0.5])
    self.assertEqual(c([1, 0.75]), [0.75])
    self.assertEqual(c([1, 3]), [1])


if __name__ == "__main__":
  absltest.main()
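The `SmoothMax` expectations above pin down the combiner's behaviour: `alpha=inf`/`-inf` reduce to max/min, `alpha=0` is a weighted mean, and `SmoothMax(1)` of `[0, 1]` gives `e/(1+e) ≈ 0.731`. A minimal sketch consistent with those values follows; the real implementation lives in `fusion_tcv/combiners.py`, which is not shown in this diff, so the NaN and zero-weight handling here are assumptions:

```python
import math

def smooth_max(values, weights=None, alpha=1.0):
  """Exponentially weighted mean; hypothetical stand-in for combiners.SmoothMax."""
  weights = weights if weights is not None else [1.0] * len(values)
  pairs = [(v, w) for v, w in zip(values, weights) if not math.isnan(v)]
  if not pairs:
    return [float("nan")]  # All inputs NaN -> NaN, as the tests expect.
  if alpha == math.inf:
    return [max(v for v, _ in pairs)]
  if alpha == -math.inf:
    return [min(v for v, _ in pairs)]
  exp_w = [w * math.exp(alpha * v) for v, w in pairs]
  return [sum(v * e for (v, _), e in zip(pairs, exp_w)) / sum(exp_w)]

assert abs(smooth_max([0, 1], alpha=1)[0] - 0.7310585786300049) < 1e-12
assert smooth_max([0, 1], [1, 3], alpha=0) == [0.75]
```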
160
fusion_tcv/environment.py
Normal file
@@ -0,0 +1,160 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Environment API for FGE simulator."""

from typing import Dict, List, Optional

import dm_env
from dm_env import auto_reset_environment
from dm_env import specs
import numpy as np

from fusion_tcv import fge_octave
from fusion_tcv import fge_state
from fusion_tcv import named_array
from fusion_tcv import noise
from fusion_tcv import param_variation
from fusion_tcv import ref_gen
from fusion_tcv import rewards
from fusion_tcv import tcv_common
from fusion_tcv import terminations


# Re-export, as fge_octave should be an implementation detail.
ShotCondition = fge_octave.ShotCondition


class Environment(auto_reset_environment.AutoResetEnvironment):
  """An environment using the FGE Solver.

  The simulator will return a flux map, which is the environment's hidden state,
  and some flux measurements, which will be used as observations. The actions
  represent current levels that are passed to the simulator for the next
  flux calculation.
  """

  def __init__(
      self,
      shot_condition: ShotCondition,
      reward: rewards.AbstractReward,
      reference_generator: ref_gen.AbstractReferenceGenerator,
      max_episode_length: int = 10000,
      termination: Optional[terminations.Abstract] = None,
      obs_act_noise: Optional[noise.Noise] = None,
      power_supply_delays: Optional[Dict[str, List[float]]] = None,
      param_generator: Optional[param_variation.ParamGenerator] = None):
    """Initializes an Environment instance.

    Args:
      shot_condition: A ShotCondition, specifying shot number and time. This
        specifies the machine geometry (eg with or without the baffles), and
        the initial measurements, voltages, current and plasma state.
      reward: Function to generate a reward term.
      reference_generator: Generator for the references signal.
      max_episode_length: Maximum number of steps before the episode is
        truncated and restarted.
      termination: Decides if the state should be considered a termination.
      obs_act_noise: Settings for the observation and action noise. If None,
        the default noise level is used.
      power_supply_delays: A dict with power supply delays (in seconds), keys
        are coil type labels ('E', 'F', 'G', 'OH'). `None` means default
        delays.
      param_generator: Generator for Liuqe parameter settings. If None, the
        default settings are used.
    """
    super().__init__()
    if power_supply_delays is None:
      power_supply_delays = tcv_common.TCV_ACTION_DELAYS
    self._simulator = fge_octave.FGESimulatorOctave(
        shot_condition=shot_condition,
        power_supply_delays=power_supply_delays)
    self._reward = reward
    self._reference_generator = reference_generator
    self._max_episode_length = max_episode_length
    self._termination = (termination if termination is not None else
                         terminations.CURRENT_OH_IP)
    self._noise = (obs_act_noise if obs_act_noise is not None else
                   noise.Noise.use_default_noise())
    self._param_generator = (param_generator if param_generator is not None
                             else param_variation.ParamGenerator())
    self._params = None
    self._step_counter = 0
    self._last_observation = None

  def observation_spec(self):
    """Defines the observations provided by the environment."""
    return tcv_common.observation_spec()

  def action_spec(self) -> specs.BoundedArray:
    """Defines the actions that should be provided to `step`."""
    return tcv_common.action_spec()

  def _reset(self) -> dm_env.TimeStep:
    """Starts a new episode."""
    self._step_counter = 0
    self._params = self._param_generator.generate()
    state = self._simulator.reset(self._params)
    references = self._reference_generator.reset()
    zero_act = np.zeros(self.action_spec().shape,
                        dtype=self.action_spec().dtype)
    self._last_observation = self._extract_observation(
        state, references, zero_act)
    return dm_env.restart(self._last_observation)

  def _simulator_voltages_from_voltages(self, voltages):
    voltage_simulator = np.copy(voltages)
    if self._params.psu_voltage_offset is not None:
      for coil, offset in self._params.psu_voltage_offset.items():
        voltage_simulator[tcv_common.TCV_ACTION_INDICES[coil]] += offset
    voltage_simulator = np.clip(
        voltage_simulator,
        self.action_spec().minimum,
        self.action_spec().maximum)
    g_coil = tcv_common.TCV_ACTION_RANGES.index("G")
    if abs(voltage_simulator[g_coil]) < tcv_common.ENV_G_COIL_DEADBAND:
      voltage_simulator[g_coil] = 0
    return voltage_simulator

  def _step(self, action: np.ndarray) -> dm_env.TimeStep:
    """Does one step within TCV."""
    voltages = self._noise.add_action_noise(action)
    voltage_simulator = self._simulator_voltages_from_voltages(voltages)
    try:
      state = self._simulator.step(voltage_simulator)
    except (fge_state.InvalidSolutionError,
            fge_state.StopSignalException):
      return dm_env.termination(
          self._reward.terminal_reward(), self._last_observation)
    references = self._reference_generator.step()
    self._last_observation = self._extract_observation(
        state, references, action)
    term = self._termination.terminate(state)
    if term:
      return dm_env.termination(
          self._reward.terminal_reward(), self._last_observation)
    reward, _ = self._reward.reward(voltages, state, references)
    self._step_counter += 1
    if self._step_counter >= self._max_episode_length:
      return dm_env.truncation(reward, self._last_observation)
    return dm_env.transition(reward, self._last_observation)

  def _extract_observation(
      self, state: fge_state.FGEState,
      references: named_array.NamedArray,
      action: np.ndarray) -> Dict[str, np.ndarray]:
    return {
        "references": references.array,
        "measurements": self._noise.add_measurement_noise(
            state.get_observation_vector()),
        "last_action": action,
    }
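A minimal interaction loop with this environment, following the standard `dm_env` protocol. This is a sketch only: `experiments.snowflake()` is one of the factories defined in `fusion_tcv/experiments.py` below, and actually running it requires the non-open-source FGE solver behind `fge_octave`:

```python
import numpy as np
from fusion_tcv import experiments

env = experiments.snowflake()
spec = env.action_spec()
timestep = env.reset()
for _ in range(100):
  if timestep.last():
    break
  # A trivial zero-voltage policy, just to exercise the step loop.
  action = np.zeros(spec.shape, dtype=spec.dtype)
  timestep = env.step(action)
```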
72
fusion_tcv/experiments.py
Normal file
@@ -0,0 +1,72 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The environment definitions used for our experiments."""

from fusion_tcv import environment
from fusion_tcv import references
from fusion_tcv import rewards_used


# Used in TCV#70915
def fundamental_capability() -> environment.Environment:
  return environment.Environment(
      shot_condition=environment.ShotCondition(70166, 0.0872),
      reward=rewards_used.FUNDAMENTAL_CAPABILITY,
      reference_generator=references.fundamental_capability(),
      max_episode_length=10000)


# Used in TCV#70920
def elongation() -> environment.Environment:
  return environment.Environment(
      shot_condition=environment.ShotCondition(70166, 0.45),
      reward=rewards_used.ELONGATION,
      reference_generator=references.elongation(),
      max_episode_length=5000)


# Used in TCV#70600
def iter() -> environment.Environment:  # pylint: disable=redefined-builtin
  return environment.Environment(
      shot_condition=environment.ShotCondition(70392, 0.0872),
      reward=rewards_used.ITER,
      reference_generator=references.iter(),
      max_episode_length=1000)


# Used in TCV#70457
def negative_triangularity() -> environment.Environment:
  return environment.Environment(
      shot_condition=environment.ShotCondition(70166, 0.45),
      reward=rewards_used.NEGATIVE_TRIANGULARITY,
      reference_generator=references.negative_triangularity(),
      max_episode_length=5000)


# Used in TCV#70755
def snowflake() -> environment.Environment:
  return environment.Environment(
      shot_condition=environment.ShotCondition(70166, 0.0872),
      reward=rewards_used.SNOWFLAKE,
      reference_generator=references.snowflake(),
      max_episode_length=10000)


# Used in TCV#69545
def droplet() -> environment.Environment:
  return environment.Environment(
      shot_condition=environment.ShotCondition(69198, 0.418),
      reward=rewards_used.DROPLETS,
      reference_generator=references.droplet(),
      max_episode_length=2000)
79
fusion_tcv/experiments_test.py
Normal file
@@ -0,0 +1,79 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for experiments."""

from absl.testing import absltest
from absl.testing import parameterized
from dm_env import test_utils

from fusion_tcv import agent
from fusion_tcv import experiments
from fusion_tcv import run_loop


class FundamentalCapabilityTest(test_utils.EnvironmentTestMixin,
                                absltest.TestCase):

  def make_object_under_test(self):
    return experiments.fundamental_capability()


class ElongationTest(test_utils.EnvironmentTestMixin, absltest.TestCase):

  def make_object_under_test(self):
    return experiments.elongation()


class IterTest(test_utils.EnvironmentTestMixin, absltest.TestCase):

  def make_object_under_test(self):
    return experiments.iter()


class NegativeTriangularityTest(test_utils.EnvironmentTestMixin,
                                absltest.TestCase):

  def make_object_under_test(self):
    return experiments.negative_triangularity()


class SnowflakeTest(test_utils.EnvironmentTestMixin, absltest.TestCase):

  def make_object_under_test(self):
    return experiments.snowflake()


class DropletTest(test_utils.EnvironmentTestMixin, absltest.TestCase):

  def make_object_under_test(self):
    return experiments.droplet()


class ExperimentsTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ("fundamental_capability", experiments.fundamental_capability),
      ("elongation", experiments.elongation),
      ("iter", experiments.iter),
      ("negative_triangularity", experiments.negative_triangularity),
      ("snowflake", experiments.snowflake),
      ("droplet", experiments.droplet),
  )
  def test_env(self, env_fn):
    traj = run_loop.run_loop(env_fn(), agent.ZeroAgent(), max_steps=10)
    self.assertGreaterEqual(len(traj.reward), 1)


if __name__ == "__main__":
  absltest.main()
77
fusion_tcv/fge_octave.py
Normal file
@@ -0,0 +1,77 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Actually interact with FGE via octave."""

from typing import Dict, List

import dataclasses
import numpy as np

from fusion_tcv import fge_state
from fusion_tcv import param_variation

SUBSTEPS = 5


@dataclasses.dataclass
class ShotCondition:
  """Represents a shot and time from a real shot."""
  shot: int
  time: float


class FGESimulatorOctave:
  """Would interact with the FGE solver via Octave.

  Given that FGE isn't open source, this is just a sketch.
  """

  def __init__(
      self,
      shot_condition: ShotCondition,
      power_supply_delays: Dict[str, List[float]]):
    """Initialize the simulator.

    Args:
      shot_condition: A ShotCondition, specifying shot number and time. This
        specifies the machine geometry (eg with or without the baffles), and
        the initial measurements, voltages, current and plasma shape.
      power_supply_delays: A dict with power supply delays (in seconds), keys
        are coil type labels ('E', 'F', 'G', 'OH'). `None` means default
        delays.
    """
    del power_supply_delays
    # Initialize the simulator:
    # - Use oct2py to load FGE through Octave.
    # - Load the data for the shot_condition.
    # - Set up the reactor geometry from the shot_condition.
    # - Set the timestep to `tcv_common.DT / SUBSTEPS`.
    # - Set up the solver for singlets or droplets based on the shot_condition.
    self._num_plasmas = 2 if shot_condition.shot == 69198 else 1
    # - Set up the power supply, including the limits, initial data, and delays.

  def reset(self, variation: param_variation.Settings) -> fge_state.FGEState:
    """Restarts the simulator with parameters."""
    del variation
    # Update the simulator with the current physics parameters.
    # Reset to the initial state from the shot_condition.
    return fge_state.FGEState(self._num_plasmas)  # Filled with the real state.

  def step(self, voltages: np.ndarray) -> fge_state.FGEState:
    """Run the simulator with `voltages`, returns the state."""
    del voltages
    # for _ in range(SUBSTEPS):
    #   Step the simulator with `voltages`.
    #   raise fge_state.InvalidSolutionError if the solver doesn't converge.
    #   raise fge_state.StopSignalException if an internal termination is
    #   triggered.
    return fge_state.FGEState(self._num_plasmas)  # Filled with the real state.
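For the record, the Octave wiring described in the comments above would go through `oct2py`, which exposes Octave functions as Python attributes. A sketch only: the `fge_*` entry points and the path are placeholders, since the real FGE interface is not public:

```python
from oct2py import Oct2Py

octave = Oct2Py()
octave.addpath("/path/to/fge")  # hypothetical location of the FGE sources
# Hypothetical call names; oct2py forwards attribute calls to Octave functions.
solver = octave.fge_init(70166, 0.0872)
state = octave.fge_step(solver, [0.0] * 19)
```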
121
fusion_tcv/fge_state.py
Normal file
@@ -0,0 +1,121 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A nice python representation of the underlying FGE state."""

from typing import List, Tuple

import numpy as np

from fusion_tcv import shape
from fusion_tcv import shapes_known
from fusion_tcv import tcv_common


class StopSignalException(Exception):  # pylint: disable=g-bad-exception-name
  """This is raised if the FGE environment raises the Stop/Alarm signal."""
  pass


class InvalidSolutionError(RuntimeError):
  """This is raised if the returned solution is invalid."""
  pass


class UnhandledOctaveError(Exception):
  """This is raised if some Octave code raises an unhandled error."""
  pass


class FGEState:
  """A nice python representation of the underlying FGE State.

  Given that FGE isn't open source, all of these numbers are made up, and only
  a sketch of what it could look like.
  """

  def __init__(self, num_plasmas):
    self._num_plasmas = num_plasmas

  @property
  def num_plasmas(self) -> int:
    return self._num_plasmas  # Return 1 for singlet, 2 for droplets.

  @property
  def rzip_d(self) -> Tuple[List[float], List[float], List[float]]:
    """Returns the R, Z, and Ip for each plasma domain."""
    if self.num_plasmas == 1:
      return [0.9], [0], [-120000]
    else:
      return [0.9, 0.88], [0.4, -0.4], [-60000, -65000]

  def get_coil_currents_by_type(self, coil_type) -> np.ndarray:
    currents = tcv_common.TCV_ACTION_RANGES.new_random_named_array()
    return currents[coil_type] * tcv_common.ENV_COIL_MAX_CURRENTS[coil_type] / 5

  def get_lcfs_points(self, domain: int) -> shape.ShapePoints:
    del domain  # Should be plasma domain specific.
    return shapes_known.SHAPE_70166_0872.canonical().points

  def get_observation_vector(self) -> np.ndarray:
    return tcv_common.TCV_MEASUREMENT_RANGES.new_random_named_array().array

  @property
  def elongation(self) -> List[float]:
    return [1.4] * self.num_plasmas

  @property
  def triangularity(self) -> List[float]:
    return [0.25] * self.num_plasmas

  @property
  def radius(self) -> List[float]:
    return [0.23] * self.num_plasmas

  @property
  def limit_point_d(self) -> List[shape.Point]:
    return [shape.Point(tcv_common.INNER_LIMITER_R, 0.2)] * self.num_plasmas

  @property
  def is_diverted_d(self) -> List[bool]:
    return [False] * self.num_plasmas

  @property
  def x_points(self) -> shape.ShapePoints:
    return []

  @property
  def flux(self) -> np.ndarray:
    """Return the flux at the grid coordinates."""
    return np.random.random((len(self.z_coordinates), len(self.r_coordinates)))

  @property
  def magnetic_axis_flux_strength(self) -> float:
    """The magnetic flux at the center of the plasma."""
    return 2

  @property
  def lcfs_flux_strength(self) -> float:
    """The flux at the LCFS."""
    return 1

  @property
  def r_coordinates(self) -> np.ndarray:
    """The radial coordinates of the simulation."""
    return np.arange(tcv_common.INNER_LIMITER_R, tcv_common.OUTER_LIMITER_R,
                     tcv_common.LIMITER_WIDTH / 10)  # Made up grid resolution.

  @property
  def z_coordinates(self):
    """The vertical coordinates of the simulation."""
    return np.arange(-0.75, 0.75, 1.5 / 30)  # Made up numbers.
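The flux-related properties above are enough to form the normalized poloidal flux that shape targets typically work with. The definition below is the standard one, not something this file itself exports:

```python
import numpy as np

def normalized_flux(state) -> np.ndarray:
  """(psi - psi_axis) / (psi_lcfs - psi_axis): 0 at the axis, 1 at the LCFS."""
  psi_axis = state.magnetic_axis_flux_strength
  psi_lcfs = state.lcfs_flux_strength
  return (state.flux - psi_axis) / (psi_lcfs - psi_axis)
```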
127
fusion_tcv/named_array.py
Normal file
@@ -0,0 +1,127 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Give names to parts of a numpy array."""

from typing import Iterable, List, Mapping, MutableMapping, Tuple, Union

import numpy as np


def lengths_to_ranges(
    lengths: Mapping[str, int]) -> MutableMapping[str, List[int]]:
  """Eg: {a: 2, b: 3} -> {a: [0, 1], b: [2, 3, 4]}."""
  ranges = {}
  start = 0
  for key, length in lengths.items():
    ranges[key] = list(range(start, start + length))
    start += length
  return ranges


class NamedRanges:
  """Given a map of {key: count}, give various views into it."""

  def __init__(self, counts: Mapping[str, int]):
    self._ranges = lengths_to_ranges(counts)
    self._size = sum(counts.values())

  def __getitem__(self, name) -> List[int]:
    return self._ranges[name]

  def __contains__(self, name) -> bool:
    return name in self._ranges

  def set_range(self, name: str, value: List[int]):
    """Overwrite or create a custom range, which may intersect with others."""
    self._ranges[name] = value

  def range(self, name: str) -> List[int]:
    return self[name]

  def index(self, name: str) -> int:
    rng = self[name]
    if len(rng) != 1:
      raise ValueError(f"{name} has multiple values")
    return rng[0]

  def count(self, name: str) -> int:
    return len(self[name])

  def names(self) -> Iterable[str]:
    return self._ranges.keys()

  def ranges(self) -> Iterable[Tuple[str, List[int]]]:
    return self._ranges.items()

  def counts(self) -> Mapping[str, int]:
    return {k: len(v) for k, v in self._ranges.items()}

  @property
  def size(self) -> int:
    return self._size

  def named_array(self, array: np.ndarray) -> "NamedArray":
    return NamedArray(array, self)

  def new_named_array(self) -> "NamedArray":
    return NamedArray(np.zeros((self.size,)), self)

  def new_random_named_array(self) -> "NamedArray":
    return NamedArray(np.random.uniform(size=(self.size,)), self)


class NamedArray:
  """Given a numpy array and a NamedRanges, access slices by name."""

  def __init__(self, array: np.ndarray, names: NamedRanges):
    if array.shape != (names.size,):
      raise ValueError(f"Wrong sizes: {array.shape} != ({names.size},)")
    self._array = array
    self._names = names

  def __getitem__(
      self, name: Union[str, Tuple[str, Union[int, List[int],
                                              slice]]]) -> np.ndarray:
    """Return a read-only view into the array by name."""
    if isinstance(name, str):
      arr = self._array[self._names[name]]
    else:
      name, i = name
      arr = self._array[np.array(self._names[name])[i]]
    if not np.isscalar(arr):
      # Read-only because it's indexed by an array of potentially
      # non-contiguous indices, which isn't representable as a normal tensor,
      # which forces a copy and therefore writes don't modify the underlying
      # array as expected.
      arr.flags.writeable = False
    return arr

  def __setitem__(
      self, name: Union[str, Tuple[str, Union[int, List[int], slice]]], value):
    """Set one or more values of a range to a value."""
    if isinstance(name, str):
      self._array[self._names[name]] = value
    else:
      name, i = name
      self._array[np.array(self._names[name])[i]] = value

  @property
  def array(self) -> np.ndarray:
    return self._array

  @property
  def names(self) -> NamedRanges:
    return self._names

  def to_dict(self) -> Mapping[str, np.ndarray]:
    return {k: self[k] for k in self._names.names()}
91
fusion_tcv/named_array_test.py
Normal file
@@ -0,0 +1,91 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for named_array."""

from absl.testing import absltest
import numpy as np

from fusion_tcv import named_array


class NamedRangesTest(absltest.TestCase):

  def test_lengths_to_ranges(self):
    self.assertEqual(named_array.lengths_to_ranges({"a": 2, "b": 3}),
                     {"a": [0, 1], "b": [2, 3, 4]})

  def test_named_ranges(self):
    action_counts = {"E": 8, "F": 8, "OH": 2, "DUMMY": 1, "G": 1}
    actions = named_array.NamedRanges(action_counts)
    self.assertEqual(actions.range("E"), list(range(8)))
    self.assertEqual(actions["F"], list(range(8, 16)))
    self.assertEqual(actions.range("G"), [19])
    self.assertEqual(actions.index("G"), 19)
    with self.assertRaises(ValueError):
      actions.index("F")
    for k, v in action_counts.items():
      self.assertEqual(actions.count(k), v)
    self.assertEqual(actions.counts(), action_counts)
    self.assertEqual(list(actions.names()), list(action_counts.keys()))
    self.assertEqual(actions.size, sum(action_counts.values()))

    refs = actions.new_named_array()
    self.assertEqual(refs.array.shape, (actions.size,))
    np.testing.assert_array_equal(refs.array, np.zeros((actions.size,)))

    refs = actions.new_random_named_array()
    self.assertEqual(refs.array.shape, (actions.size,))
    self.assertFalse(np.array_equal(refs.array, np.zeros((actions.size,))))


class NamedArrayTest(absltest.TestCase):

  def test_name_array(self):
    action_counts = {"E": 8, "F": 8, "OH": 2, "DUMMY": 1, "G": 1}
    actions_ranges = named_array.NamedRanges(action_counts)
    actions_array = np.arange(actions_ranges.size) + 100
    actions = named_array.NamedArray(actions_array, actions_ranges)
    for k in action_counts:
      self.assertEqual(list(actions[k]), [v + 100 for v in actions_ranges[k]])
    actions["G"] = -5
    self.assertEqual(list(actions["G"]), [-5])
    self.assertEqual(actions_array[19], -5)

    for i in range(action_counts["E"]):
      actions.names.set_range(f"E_{i}", [i])

    actions["E_3"] = 53
    self.assertEqual(list(actions["E_1"]), [101])
    self.assertEqual(list(actions["E_3"]), [53])
    self.assertEqual(actions_array[3], 53)

    actions["F", 2] = 72
    self.assertEqual(actions_array[10], 72)

    actions["F", [4, 5]] = 74
    self.assertEqual(actions_array[12], 74)
    self.assertEqual(actions_array[13], 74)

    actions["F", 0:2] = 78
    self.assertEqual(actions_array[8], 78)
    self.assertEqual(actions_array[9], 78)

    self.assertEqual(list(actions["F"]), [78, 78, 72, 111, 74, 74, 114, 115])

    with self.assertRaises(ValueError):
      actions["F"][5] = 85


if __name__ == "__main__":
  absltest.main()
117
fusion_tcv/noise.py
Normal file
@@ -0,0 +1,117 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Settings for adding noise to the action and measurements."""

import numpy as np
from numpy import random

from fusion_tcv import tcv_common


class Noise:
  """Class for adding noise to the action and measurements."""

  def __init__(self,
               action_mean=None,
               action_std=None,
               measurements_mean=None,
               measurements_std=None,
               seed=None):
    """Initializes the class.

    Args:
      action_mean: mean of the Gaussian noise (action bias).
      action_std: std of the Gaussian action noise.
      measurements_mean: mean of the Gaussian noise (measurement bias).
      measurements_std: Dictionary mapping the tcv measurement names to noise.
      seed: seed for the random number generator. If None, the seed is unset.
    """
    # Check all of the shapes are present and correct.
    assert action_std.shape == (tcv_common.NUM_ACTIONS,)
    assert action_mean.shape == (tcv_common.NUM_ACTIONS,)
    for name, num in tcv_common.TCV_MEASUREMENTS.items():
      assert name in measurements_std
      assert measurements_mean[name].shape == (num,)
      assert measurements_std[name].shape == (num,)
    self._action_mean = action_mean
    self._action_std = action_std
    self._meas_mean = measurements_mean
    self._meas_std = measurements_std
    self._meas_mean_vec = tcv_common.dict_to_measurement(self._meas_mean)
    self._meas_std_vec = tcv_common.dict_to_measurement(self._meas_std)
    self._gen = random.RandomState(seed)

  @classmethod
  def use_zero_noise(cls):
    no_noise_mean = dict()
    no_noise_std = dict()
    for name, num in tcv_common.TCV_MEASUREMENTS.items():
      no_noise_mean[name] = np.zeros((num,))
      no_noise_std[name] = np.zeros((num,))
    return cls(
        action_mean=np.zeros((tcv_common.NUM_ACTIONS)),
        action_std=np.zeros((tcv_common.NUM_ACTIONS)),
        measurements_mean=no_noise_mean,
        measurements_std=no_noise_std)

  @classmethod
  def use_default_noise(cls, scale=1):
    """Returns the default observation noise parameters."""

    # There is no noise added to the actions, because the noise should be added
    # to the action after/as part of the power supply model, as opposed to the
    # input to the power supply model.
    action_noise_mean = np.zeros((tcv_common.NUM_ACTIONS))
    action_noise_std = np.zeros((tcv_common.NUM_ACTIONS))

    meas_noise_mean = dict()
    for key, l in tcv_common.TCV_MEASUREMENTS.items():
      meas_noise_mean[key] = np.zeros((l,))
    meas_noise_std = dict(
        clint_vloop=np.array([0]),
        clint_rvloop=np.array([scale * 1e-4] * 37),
        bm=np.array([scale * 1e-4] * 38),
        IE=np.array([scale * 20] * 8),
        IF=np.array([scale * 5] * 8),
        IOH=np.array([scale * 20] * 2),
        Bdot=np.array([scale * 0.05] * 20),
        DIOH=np.array([scale * 30]),
        FIR_FRINGE=np.array([0]),
        IG=np.array([scale * 2.5]),
        ONEMM=np.array([0]),
        vloop=np.array([scale * 0.3]),
        IPHI=np.array([0]),
    )
    return cls(
        action_mean=action_noise_mean,
        action_std=action_noise_std,
        measurements_mean=meas_noise_mean,
        measurements_std=meas_noise_std)

  def add_action_noise(self, action):
    errs = self._gen.normal(size=action.shape,
                            loc=self._action_mean,
                            scale=self._action_std)
    return action + errs

  def add_measurement_noise(self, measurement_vec):
    errs = self._gen.normal(size=measurement_vec.shape,
                            loc=self._meas_mean_vec,
                            scale=self._meas_std_vec)
    # Make the IOH measurements consistent. The "real" measurements are IOH
    # and DIOH, so use those.
    errs = tcv_common.measurements_to_dict(errs)
    errs["IOH"][1] = errs["IOH"][0] + errs["DIOH"][0]
    errs = tcv_common.dict_to_measurement(errs)
    return measurement_vec + errs
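Scaling the default noise up or down is the intended knob. A usage sketch; note that with the defaults, `add_action_noise` is the identity, since the action std is zero:

```python
from fusion_tcv import noise

quiet = noise.Noise.use_zero_noise()              # no noise at all
doubled = noise.Noise.use_default_noise(scale=2)  # 2x the default measurement std
```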
126
fusion_tcv/param_variation.py
Normal file
@@ -0,0 +1,126 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for varying parameters from simulation to simulation."""

from typing import Dict, Optional, Tuple

import dataclasses
import numpy as np

from fusion_tcv import tcv_common

# Pylint does not like variable names like `qA`.
# pylint: disable=invalid-name

RP_DEFAULT = 5e-6
LP_DEFAULT = 2.05e-6
BP_DEFAULT = 0.25
QA_DEFAULT = 1.3


@dataclasses.dataclass
class Settings:
  """Settings to modify the solver/plasma model."""
  # Inverse of the resistivity.
  # The plasma circuit equation is roughly
  #   k * dIoh/dt = L * dIp/dt + R * I = Vloop
  # where R is roughly (1 / signeo) or rp.
  # The value is a multiplier on the default value.
  # This parameter does not apply to the OhmTor diffusion.
  signeo: Tuple[float, float] = (1, 1)
  # Rp: plasma resistivity. The value is an absolute value.
  rp: float = RP_DEFAULT
  # Plasma self-inductance. The value is an absolute value.
  lp: float = LP_DEFAULT
  # Proportional to the plasma pressure. The value is an absolute value.
  bp: float = BP_DEFAULT
  # Plasma current profile. The value is absolute.
  qA: float = QA_DEFAULT
  # Initial OH coil current. Applied to both coils.
  ioh: Optional[float] = None
  # The voltage offsets for the various coils.
  psu_voltage_offset: Optional[Dict[str, float]] = None

  def _psu_voltage_offset_string(self) -> str:
    """Return a short-ish, readable string of the psu voltage offsets."""
    if not self.psu_voltage_offset:
      return "None"
    if len(self.psu_voltage_offset) < 8:  # Only a few, output individually.
      return ", ".join(
          f"{coil.replace('_00', '')}: {offset:.0f}"
          for coil, offset in self.psu_voltage_offset.items())
    # Otherwise, too long, so output in groups.
    groups = []
    for coil, action_range in tcv_common.TCV_ACTION_RANGES.ranges():
      offsets = [self.psu_voltage_offset.get(tcv_common.TCV_ACTIONS[i], 0)
                 for i in action_range]
      if any(offsets):
        groups.append(f"{coil}: " + ",".join(f"{offset:.0f}"
                                             for offset in offsets))
    return ", ".join(groups)


class ParamGenerator:
  """Varies parameters using uniform/loguniform distributions.

  Absolute parameters are varied using uniform distributions while scaling
  parameters use a loguniform distribution.
  """

  def __init__(self,
               rp_bounds: Optional[Tuple[float, float]] = None,
               lp_bounds: Optional[Tuple[float, float]] = None,
               qA_bounds: Optional[Tuple[float, float]] = None,
               bp_bounds: Optional[Tuple[float, float]] = None,
               rp_mean: float = RP_DEFAULT,
               lp_mean: float = LP_DEFAULT,
               bp_mean: float = BP_DEFAULT,
               qA_mean: float = QA_DEFAULT,
               ioh_bounds: Optional[Tuple[float, float]] = None,
               psu_voltage_offset_bounds: Optional[
                   Dict[str, Tuple[float, float]]] = None):
    # Do not allow Signeo variation, as this does not work with OhmTor current
    # diffusion.
    no_scaling = (1, 1)
    self._rp_bounds = rp_bounds if rp_bounds else no_scaling
    self._lp_bounds = lp_bounds if lp_bounds else no_scaling
    self._bp_bounds = bp_bounds if bp_bounds else no_scaling
    self._qA_bounds = qA_bounds if qA_bounds else no_scaling
    self._rp_mean = rp_mean
    self._lp_mean = lp_mean
    self._bp_mean = bp_mean
    self._qA_mean = qA_mean
    self._ioh_bounds = ioh_bounds
    self._psu_voltage_offset_bounds = psu_voltage_offset_bounds

  def generate(self) -> Settings:
    return Settings(
        signeo=(1, 1),
        rp=loguniform_rv(*self._rp_bounds) * self._rp_mean,
        lp=loguniform_rv(*self._lp_bounds) * self._lp_mean,
        bp=loguniform_rv(*self._bp_bounds) * self._bp_mean,
        qA=loguniform_rv(*self._qA_bounds) * self._qA_mean,
        ioh=np.random.uniform(*self._ioh_bounds) if self._ioh_bounds else None,
        psu_voltage_offset=(
            {coil: np.random.uniform(*bounds)
             for coil, bounds in self._psu_voltage_offset_bounds.items()}
            if self._psu_voltage_offset_bounds else None))


def loguniform_rv(lower, upper):
  """Generate a loguniform random variable between `lower` and `upper`."""
  if lower == upper:
    return lower
  assert lower < upper
  return np.exp(np.random.uniform(np.log(lower), np.log(upper)))
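A sampling sketch: the bounds are multipliers around each `*_mean`, drawn log-uniformly, so `(0.5, 2.0)` varies `rp` between half and double its default:

```python
from fusion_tcv import param_variation

gen = param_variation.ParamGenerator(rp_bounds=(0.5, 2.0))
settings = gen.generate()
# rp = loguniform_rv(0.5, 2.0) * RP_DEFAULT, so it stays within these bounds.
assert (0.5 * param_variation.RP_DEFAULT
        <= settings.rp
        <= 2.0 * param_variation.RP_DEFAULT)
```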
202
fusion_tcv/ref_gen.py
Normal file
@@ -0,0 +1,202 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generators for the references vector."""

import abc
import copy
from typing import List, Optional

import dataclasses

from fusion_tcv import named_array
from fusion_tcv import shape as shape_lib
from fusion_tcv import tcv_common


class AbstractReferenceGenerator(abc.ABC):
  """Abstract class for generating the reference signal."""

  @abc.abstractmethod
  def reset(self) -> named_array.NamedArray:
    """Resets the class for a new episode and returns the first reference."""

  @abc.abstractmethod
  def step(self) -> named_array.NamedArray:
    """Returns the reference signal."""


@dataclasses.dataclass
class LinearTransition:
  reference: named_array.NamedArray  # Reference at which to end the transition.
  transition_steps: int  # Number of intermediate steps between the shapes.
  steady_steps: int  # Number of steps in the steady state.


class LinearTransitionReferenceGenerator(AbstractReferenceGenerator):
  """A base class for generating references that are a series of transitions."""

  def __init__(self, start_offset: int = 0):
    self._last_ref = None
    self._reset_counters()
    self._start_offset = start_offset

  @abc.abstractmethod
  def _next_transition(self) -> LinearTransition:
    """Override this in the subclass."""

  def reset(self) -> named_array.NamedArray:
    self._last_ref = None
    self._reset_counters()
    for _ in range(self._start_offset):
      self.step()
    return self.step()

  def _reset_counters(self):
    self._steady_step = 0
    self._transition_step = 0
    self._transition = None

  def step(self) -> named_array.NamedArray:
    if (self._transition is None or
        self._steady_step == self._transition.steady_steps):
      if self._transition is not None:
        self._last_ref = self._transition.reference
      self._reset_counters()
      self._transition = self._next_transition()
      # Ensure at least one steady step in middle transitions.
      # If we would like this to not have to be true, we need to change the
      # logic below which assumes there is at least one step in the steady
      # phase.
      assert self._transition.steady_steps > 0

    assert self._transition is not None  # to make pytype happy
    transition_steps = self._transition.transition_steps
    if self._last_ref is None:  # No transition at beginning of episode.
      transition_steps = 0

    if self._transition_step < transition_steps:  # In transition phase.
      self._transition_step += 1
      a = self._transition_step / (self._transition.transition_steps + 1)  # pytype: disable=attribute-error
      return self._last_ref.names.named_array(
          self._last_ref.array * (1 - a) + self._transition.reference.array * a)  # pytype: disable=attribute-error
    else:  # In steady phase.
      self._steady_step += 1
      return copy.deepcopy(self._transition.reference)


class FixedReferenceGenerator(LinearTransitionReferenceGenerator):
  """Generates linear transitions from a fixed set of references."""

  def __init__(self, transitions: List[LinearTransition],
               start_offset: int = 0):
    self._transitions = transitions
    self._current_transition = 0
    super().__init__(start_offset=start_offset)

  def reset(self) -> named_array.NamedArray:
    self._current_transition = 0
    return super().reset()

  def _next_transition(self) -> LinearTransition:
    if self._current_transition == len(self._transitions):
      # Have gone through all of the transitions. Return the final reference
      # for a very long time.
      return LinearTransition(steady_steps=50000, transition_steps=0,
                              reference=self._transitions[-1].reference)
    self._current_transition += 1
    return copy.deepcopy(self._transitions[self._current_transition - 1])


@dataclasses.dataclass
class TimedTransition:
  steady_steps: int  # Number of steps to hold the shape.
  transition_steps: int  # Number of steps to transition.


@dataclasses.dataclass
class ParametrizedShapeTimedTarget:
  """RZIP condition with a timestep attached."""
  shape: shape_lib.Shape
  timing: TimedTransition


class PresetShapePointsReferenceGenerator(FixedReferenceGenerator):
  """Generates a fixed set of shape points."""

  def __init__(
      self, targets: List[ParametrizedShapeTimedTarget], start_offset: int = 0):
    if targets[0].timing.transition_steps != 0:
      raise ValueError("Invalid first timing, transition must be 0, not "
                       f"{targets[0].timing.transition_steps}")
    transitions = []
    for target in targets:
      transitions.append(LinearTransition(
          steady_steps=target.timing.steady_steps,
          transition_steps=target.timing.transition_steps,
          reference=target.shape.canonical().gen_references()))
    super().__init__(transitions, start_offset=start_offset)


class ShapeFromShot(PresetShapePointsReferenceGenerator):
  """Generate shapes from EPFL references."""

  def __init__(
      self, time_slices: List[shape_lib.ReferenceTimeSlice],
      start: Optional[float] = None):
    """Given a series of time slices, start from time_slice.time==start."""
    if start is None:
      start = time_slices[0].time
    dt = 1e-4
    targets = []
    time_slices = shape_lib.canonicalize_reference_series(time_slices)
    prev = None
    for i, ref in enumerate(time_slices):
      assert prev is None or prev.hold < ref.time
      if ref.time < start:
        continue
      if prev is None and start != ref.time:
        raise ValueError("start must be one of the time slice times.")

      steady = (max(1, int((ref.hold - ref.time) / dt))
                if i < len(time_slices) - 1 else 100000)
      transition = (0 if prev is None else
                    (int((ref.time - prev.time) / dt) -
                     max(1, int((prev.hold - prev.time) / dt))))

      targets.append(ParametrizedShapeTimedTarget(
          shape=ref.shape,
          timing=TimedTransition(
              steady_steps=steady, transition_steps=transition)))
      prev = ref

    assert targets
    super().__init__(targets)


@dataclasses.dataclass
class RZIpTarget:
  r: float
  z: float
  ip: float


def make_symmetric_multidomain_rzip_reference(
    target: RZIpTarget) -> named_array.NamedArray:
  """Generate multi-domain rzip references."""
  refs = tcv_common.REF_RANGES.new_named_array()
  refs["R"] = (target.r, target.r)
  refs["Z"] = (target.z, -target.z)
  refs["Ip"] = (target.ip, target.ip)
  return refs
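A toy schedule for `FixedReferenceGenerator`. A sketch only: `refs_a` and `refs_b` stand in for `NamedArray` references built elsewhere, e.g. via `make_symmetric_multidomain_rzip_reference` above:

```python
from fusion_tcv import ref_gen

gen = ref_gen.FixedReferenceGenerator([
    ref_gen.LinearTransition(reference=refs_a, transition_steps=0,
                             steady_steps=100),
    ref_gen.LinearTransition(reference=refs_b, transition_steps=50,
                             steady_steps=100),
])
ref = gen.reset()   # refs_a; there is no transition at the start of an episode
for _ in range(200):
  ref = gen.step()  # holds refs_a, blends linearly to refs_b, then holds it
```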
1589
fusion_tcv/references.py
Normal file
File diff suppressed because it is too large
58
fusion_tcv/references_main.py
Normal file
@@ -0,0 +1,58 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Print the values of references."""

from typing import Sequence

from absl import app
from absl import flags

from fusion_tcv import named_array
from fusion_tcv import references
from fusion_tcv import tcv_common


_refs = flags.DEFINE_enum("refs", None, references.REFERENCES.keys(),
                          "Which references to print")
_count = flags.DEFINE_integer("count", 100, "How many timesteps to print.")
_freq = flags.DEFINE_integer("freq", 1, "Print only every so often.")
_fields = flags.DEFINE_multi_enum(
    "field", None, tcv_common.REF_RANGES.names(),
    "Which reference fields to print, default of all.")
flags.mark_flag_as_required("refs")


def print_ref(step: int, ref: named_array.NamedArray):
  print(f"Step: {step}")
  for k in (_fields.value or ref.names.names()):
    print(f"  {k}: [{', '.join(f'{v:.3f}' for v in ref[k])}]")


def main(argv: Sequence[str]) -> None:
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  if _freq.value <= 0:
    raise app.UsageError("`freq` must be >0.")

  ref = references.REFERENCES[_refs.value]()
  print_ref(0, ref.reset())
  for i in range(1, _count.value + 1):
    for _ in range(_freq.value - 1):
      ref.step()
    print_ref(i * _freq.value, ref.step())
  print(f"Stopped after {_count.value * _freq.value} steps.")


if __name__ == "__main__":
  app.run(main)
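Putting the flags together, a typical invocation looks like the following. `<name>` must be one of the keys of `references.REFERENCES`, whose diff is suppressed above; `R` and `Ip` are reference fields used elsewhere in this commit:

```shell
python -m fusion_tcv.references_main --refs=<name> --count=50 --freq=10 --field=R --field=Ip
```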
4
fusion_tcv/requirements.txt
Normal file
@@ -0,0 +1,4 @@
absl-py>=0.13.0
dm-env>=1.5
numpy>=1.21.0
scipy>=1.7.0
185
fusion_tcv/rewards.py
Normal file
@@ -0,0 +1,185 @@
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS-IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Reward function for the fusion environment."""
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import functools
|
||||
from typing import Callable, Dict, List, Optional, Text, Tuple, Union
|
||||
|
||||
from absl import logging
|
||||
import dataclasses
|
||||
import numpy as np
|
||||
|
||||
from fusion_tcv import combiners
|
||||
from fusion_tcv import fge_state
|
||||
from fusion_tcv import named_array
|
||||
from fusion_tcv import targets as targets_lib
|
||||
from fusion_tcv import transforms
|
||||
|
||||
|
||||
class AbstractMeasure(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, targets: List[targets_lib.Target]) -> List[float]:
|
||||
"""Returns a list of error measures."""
|
||||
|
||||
|
||||
class AbsDist(AbstractMeasure):
|
||||
"""Return the absolute distance between the actual and target."""
|
||||
|
||||
@staticmethod
|
||||
def __call__(targets: List[targets_lib.Target]) -> List[float]:
|
||||
return [abs(t.actual - t.target) for t in targets]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class MeasureDetails:
|
||||
min: float
|
||||
mean: float
|
||||
max: float
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RewardDetails:
|
||||
reward: float # 0-1 reward value.
|
||||
weighted: float # Should sum to < 0-1.
|
||||
weight: float
|
||||
measure: Optional[MeasureDetails] = None
|
||||
|
||||
|
||||
class AbstractReward(abc.ABC):
|
||||
"""Abstract reward class."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def reward(
|
||||
self,
|
||||
voltages: np.ndarray,
|
||||
state: fge_state.FGEState,
|
||||
references: named_array.NamedArray,
|
||||
) -> Tuple[float, Dict[Text, List[RewardDetails]]]:
|
||||
"""Returns the reward and log dict as a function of the penalty term."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def terminal_reward(self) -> float:
|
||||
"""Returns the reward if the simulator crashed."""
|
||||
|
||||
|
||||
WeightFn = Callable[[named_array.NamedArray], float]
|
||||
WeightOrFn = Union[float, WeightFn]
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Component:
|
||||
target: targets_lib.AbstractTarget
|
||||
transforms: List[transforms.AbstractTransform]
|
||||
measure: AbstractMeasure = dataclasses.field(default_factory=AbsDist)
|
||||
weight: Union[WeightOrFn, List[WeightOrFn]] = 1
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class Reward(AbstractReward):
|
||||
"""Combines a bunch of reward components into a single reward.
|
||||
|
||||
The component parts are applied in the order: target, measure, transform.
|
||||
- Targets represent some error value as one or more pair of values
|
||||
(target, actual), usually with some meaningful physical unit (eg distance,
|
||||
volts, etc).
|
||||
- Measures combine the (target, actual) into a single float, for example
|
||||
absolute distance, for each error value.
|
||||
- Transforms can make arbitrary conversions, but one of them must change from
|
||||
the arbitrary (often meaningful) scale to a reward in the 0-1 range.
|
||||
- Combiners are a special type of transform that reduces a vector of values
|
||||
down to a single value. The combiner can be skipped if the target only
|
||||
outputs a single value, or if you want a vector of outputs for the final
|
||||
combiner.
|
||||
- The component weights are passed to the final combiner, and must match the
|
||||
number of outputs for that component.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
components: List[Component],
|
||||
combiner: combiners.AbstractCombiner,
|
||||
terminal_reward: float = -5,
|
||||
reward_scale: float = 0.01):
|
||||
self._components = components
|
||||
self._combiner = combiner
|
||||
self._terminal_reward = terminal_reward
|
||||
self._reward_scale = reward_scale
|
||||
|
||||
self._weights = []
|
||||
component_count = collections.Counter()
|
||||
for component in self._components:
|
||||
num_outputs = component.target.outputs
|
||||
for transform in component.transforms:
|
||||
if transform.outputs is not None:
|
||||
num_outputs = transform.outputs
|
||||
if not isinstance(component.weight, list):
|
||||
component.weight = [component.weight]
|
||||
if len(component.weight) != num_outputs:
|
||||
name = component.name or component.target.name
|
||||
raise ValueError(f"Wrong number of weights for '{name}': got:"
|
||||
f" {len(component.weight)}, expected: {num_outputs}")
|
||||
self._weights.extend(component.weight)
|
||||
|
||||
def terminal_reward(self) -> float:
|
||||
return self._terminal_reward * self._reward_scale
|
||||
|
||||
  def reward(
      self,
      voltages: np.ndarray,
      state: fge_state.FGEState,
      references: named_array.NamedArray,
  ) -> Tuple[float, Dict[Text, List[RewardDetails]]]:
    values = []
    weights = [weight(references) if callable(weight) else weight
               for weight in self._weights]
    reward_dict = collections.defaultdict(list)
    for component in self._components:
      name = component.name or component.target.name
      num_outputs = len(component.weight)
      component_weights = weights[len(values):(len(values) + num_outputs)]
      try:
        target = component.target(voltages, state, references)
      except targets_lib.TargetError:
        logging.exception("Target failed.")
        # A failed target turns into the minimum reward.
        measure = [987654321] * num_outputs
        transformed = [0] * num_outputs
      else:
        measure = component.measure(target)
        transformed = functools.reduce(
            (lambda e, fn: fn(e)), component.transforms, measure)
        assert len(transformed) == num_outputs
        for v in transformed:
          if not np.isnan(v) and not 0 <= v <= 1:
            raise ValueError(f"The transformed value in {name} is invalid: {v}")
      values.extend(transformed)
      for weight, value in zip(component_weights, transformed):
        measure = [m for m in measure if not np.isnan(m)] or [float("nan")]
        reward_dict[name].append(RewardDetails(
            value, weight * value * self._reward_scale,
            weight if not np.isnan(value) else 0,
            MeasureDetails(
                min(measure), sum(measure) / len(measure), max(measure))))

    sum_weights = sum(sum(d.weight for d in detail)
                      for detail in reward_dict.values())
    for reward_details in reward_dict.values():
      for detail in reward_details:
        detail.weighted /= sum_weights

    final_combined = self._combiner(values, weights)
    assert len(final_combined) == 1
    return final_combined[0] * self._reward_scale, reward_dict
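
# A minimal usage sketch (hypothetical values; the real targets, transforms
# and combiners live in their own modules and are composed in
# rewards_used.py below):
#
#   reward = Reward(
#       components=[Component(
#           target=targets.Ip(),
#           transforms=[transforms.SoftPlus(good=500, bad=20000)])],
#       combiner=combiners.Mean())
#   value, details = reward.reward(voltages, state, references)
#
# `value` is the scalar reward (scaled by `reward_scale`), and `details` maps
# each component name to its per-output RewardDetails breakdown.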
253
fusion_tcv/rewards_used.py
Normal file
@@ -0,0 +1,253 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The rewards used in our experiments."""

from fusion_tcv import combiners
from fusion_tcv import rewards
from fusion_tcv import targets
from fusion_tcv import transforms


# Used in TCV#70915
FUNDAMENTAL_CAPABILITY = rewards.Reward([
    rewards.Component(
        target=targets.ShapeLCFSDistance(),
        transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
                    combiners.SmoothMax(-1)]),
    rewards.Component(
        target=targets.XPointFar(),
        transforms=[transforms.Sigmoid(good=0.3, bad=0.1),
                    combiners.SmoothMax(-5)]),
    rewards.Component(
        target=targets.LimitPoint(),
        transforms=[transforms.Sigmoid(bad=0.2, good=0.1)]),
    rewards.Component(
        target=targets.XPointNormalizedFlux(num_points=1),
        transforms=[transforms.SoftPlus(bad=0.08)]),
    rewards.Component(
        target=targets.XPointDistance(num_points=1),
        transforms=[transforms.Sigmoid(good=0.01, bad=0.15)]),
    rewards.Component(
        target=targets.XPointFluxGradient(num_points=1),
        transforms=[transforms.SoftPlus(bad=3)],
        weight=0.5),
    rewards.Component(
        target=targets.Ip(),
        transforms=[transforms.SoftPlus(good=500, bad=20000)]),
    rewards.Component(
        target=targets.OHCurrentsClose(),
        transforms=[transforms.SoftPlus(good=50, bad=1050)]),
], combiners.SmoothMax(-0.5))


# Used in TCV#70920
ELONGATION = rewards.Reward([
    rewards.Component(
        target=targets.ShapeLCFSDistance(),
        transforms=[transforms.SoftPlus(good=0.003, bad=0.03),
                    combiners.SmoothMax(-1)],
        weight=3),
    rewards.Component(
        target=targets.ShapeRadius(),
        transforms=[transforms.SoftPlus(good=0.002, bad=0.02)]),
    rewards.Component(
        target=targets.ShapeElongation(),
        transforms=[transforms.SoftPlus(good=0.005, bad=0.2)]),
    rewards.Component(
        target=targets.ShapeTriangularity(),
        transforms=[transforms.SoftPlus(good=0.005, bad=0.2)]),
    rewards.Component(
        target=targets.XPointCount(),
        transforms=[transforms.Equal()]),
    rewards.Component(
        target=targets.LimitPoint(),  # Stay away from the top/baffles.
        transforms=[transforms.Sigmoid(bad=0.3, good=0.2)]),
    rewards.Component(
        target=targets.Ip(),
        transforms=[transforms.SoftPlus(good=500, bad=30000)]),
    rewards.Component(
        target=targets.VoltageOOB(),
        transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
    rewards.Component(
        target=targets.OHCurrentsClose(),
        transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
    rewards.Component(
        name="CurrentsFarFromZero",
        target=targets.EFCurrents(),
        transforms=[transforms.Abs(),
                    transforms.SoftPlus(good=100, bad=50),
                    combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-5))


# Used in TCV#70600
ITER = rewards.Reward([
    rewards.Component(
        target=targets.ShapeLCFSDistance(),
        transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
                    combiners.SmoothMax(-1)],
        weight=3),
    rewards.Component(
        target=targets.Diverted(),
        transforms=[transforms.Equal()]),
    rewards.Component(
        target=targets.XPointNormalizedFlux(num_points=2),
        transforms=[transforms.SoftPlus(bad=0.08)],
        weight=[1] * 2),
    rewards.Component(
        target=targets.XPointDistance(num_points=2),
        transforms=[transforms.Sigmoid(good=0.01, bad=0.15)],
        weight=[0.5] * 2),
    rewards.Component(
        target=targets.XPointFluxGradient(num_points=2),
        transforms=[transforms.SoftPlus(bad=3)],
        weight=[0.5] * 2),
    rewards.Component(
        target=targets.LegsNormalizedFlux(),
        transforms=[transforms.Sigmoid(good=0.1, bad=0.3),
                    combiners.SmoothMax(-5)],
        weight=2),
    rewards.Component(
        target=targets.Ip(),
        transforms=[transforms.SoftPlus(good=500, bad=20000)],
        weight=2),
    rewards.Component(
        target=targets.VoltageOOB(),
        transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
    rewards.Component(
        target=targets.OHCurrentsClose(),
        transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
    rewards.Component(
        name="CurrentsFarFromZero",
        target=targets.EFCurrents(),
        transforms=[transforms.Abs(),
                    transforms.SoftPlus(good=100, bad=50),
                    combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-5))


# Used in TCV#70755
SNOWFLAKE = rewards.Reward([
    rewards.Component(
        target=targets.ShapeLCFSDistance(),
        transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
                    combiners.SmoothMax(-1)],
        weight=3),
    rewards.Component(
        target=targets.LimitPoint(),
        transforms=[transforms.Sigmoid(bad=0.2, good=0.1)]),
    rewards.Component(
        target=targets.XPointNormalizedFlux(num_points=2),
        transforms=[transforms.SoftPlus(bad=0.08)],
        weight=[1] * 2),
    rewards.Component(
        target=targets.XPointDistance(num_points=2),
        transforms=[transforms.Sigmoid(good=0.01, bad=0.15)],
        weight=[0.5] * 2),
    rewards.Component(
        target=targets.XPointFluxGradient(num_points=2),
        transforms=[transforms.SoftPlus(bad=3)],
        weight=[0.5] * 2),
    rewards.Component(
        target=targets.LegsNormalizedFlux(),
        transforms=[transforms.Sigmoid(good=0.1, bad=0.3),
                    combiners.SmoothMax(-5)],
        weight=2),
    rewards.Component(
        target=targets.Ip(),
        transforms=[transforms.SoftPlus(good=500, bad=20000)],
        weight=2),
    rewards.Component(
        target=targets.VoltageOOB(),
        transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
    rewards.Component(
        target=targets.OHCurrentsClose(),
        transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
    rewards.Component(
        name="CurrentsFarFromZero",
        target=targets.EFCurrents(),
        transforms=[transforms.Abs(),
                    transforms.SoftPlus(good=100, bad=50),
                    combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-5))


# Used in TCV#70457
NEGATIVE_TRIANGULARITY = rewards.Reward([
    rewards.Component(
        target=targets.ShapeLCFSDistance(),
        transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
                    combiners.SmoothMax(-1)],
        weight=3),
    rewards.Component(
        target=targets.ShapeRadius(),
        transforms=[transforms.SoftPlus(bad=0.04)]),
    rewards.Component(
        target=targets.ShapeElongation(),
        transforms=[transforms.SoftPlus(bad=0.5)]),
    rewards.Component(
        target=targets.ShapeTriangularity(),
        transforms=[transforms.SoftPlus(bad=0.5)]),
    rewards.Component(
        target=targets.Diverted(),
        transforms=[transforms.Equal()]),
    rewards.Component(
        target=targets.XPointNormalizedFlux(num_points=2),
        transforms=[transforms.SoftPlus(bad=0.08)],
        weight=[1] * 2),
    rewards.Component(
        target=targets.XPointDistance(num_points=2),
        transforms=[transforms.Sigmoid(good=0.02, bad=0.15)],
        weight=[0.5] * 2),
    rewards.Component(
        target=targets.XPointFluxGradient(num_points=2),
        transforms=[transforms.SoftPlus(bad=3)],
        weight=[0.5] * 2),
    rewards.Component(
        target=targets.Ip(),
        transforms=[transforms.SoftPlus(good=500, bad=20000)],
        weight=2),
    rewards.Component(
        target=targets.VoltageOOB(),
        transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
    rewards.Component(
        target=targets.OHCurrentsClose(),
        transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
    rewards.Component(
        name="CurrentsFarFromZero",
        target=targets.EFCurrents(),
        transforms=[transforms.Abs(),
                    transforms.SoftPlus(good=100, bad=50),
                    combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-0.5))


# Used in TCV#69545
DROPLETS = rewards.Reward([
    rewards.Component(
        target=targets.R(indices=[0, 1]),
        transforms=[transforms.Sigmoid(good=0.02, bad=0.5)],
        weight=[1, 1]),
    rewards.Component(
        target=targets.Z(indices=[0, 1]),
        transforms=[transforms.Sigmoid(good=0.02, bad=0.2)],
        weight=[1, 1]),
    rewards.Component(
        target=targets.Ip(indices=[0, 1]),
        transforms=[transforms.Sigmoid(good=2000, bad=20000)],
        weight=[1, 1]),
    rewards.Component(
        target=targets.OHCurrentsClose(),
        transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
], combiner=combiners.GeometricMean())
40
fusion_tcv/run_loop.py
Normal file
@@ -0,0 +1,40 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Run an agent on the environment."""

import numpy as np

from fusion_tcv import environment
from fusion_tcv import trajectory


def run_loop(env: environment.Environment, agent,
             max_steps: int = 100000) -> trajectory.Trajectory:
  """Run an agent."""
  results = []
  agent.reset()
  ts = env.reset()
  for _ in range(max_steps):
    obs = ts.observation
    action = agent.step(ts)
    ts = env.step(action)
    results.append(trajectory.Trajectory(
        measurements=obs["measurements"],
        references=obs["references"],
        actions=action,
        reward=np.array(ts.reward)))
    if ts.last():
      break

  return trajectory.Trajectory.stack(results)
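
# A minimal sketch of driving the loop (the RandomAgent below is hypothetical;
# any object with reset() and step(timestep) -> action works):
#
#   from fusion_tcv import tcv_common
#
#   class RandomAgent:
#     def reset(self):
#       pass
#
#     def step(self, ts):
#       return np.random.uniform(-1, 1, tcv_common.NUM_ACTIONS)
#
#   traj = run_loop(env, RandomAgent(), max_steps=1000)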
468
fusion_tcv/shape.py
Normal file
@@ -0,0 +1,468 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for time varying shape control."""

import copy
import enum
import random
from typing import List, Optional, NamedTuple, Tuple, Union

import dataclasses
import numpy as np
from scipy import interpolate

from fusion_tcv import named_array
from fusion_tcv import tcv_common


class Point(NamedTuple):
  """A point in r,z coordinates."""
  r: float
  z: float

  def to_polar(self) -> "PolarPoint":
    return PolarPoint(np.arctan2(self.z, self.r),
                      np.sqrt(self.r**2 + self.z**2))

  def __neg__(self):
    return Point(-self.r, -self.z)

  def __add__(self, pt_or_val: Union["Point", float]):
    if isinstance(pt_or_val, Point):
      return Point(self.r + pt_or_val.r, self.z + pt_or_val.z)
    else:
      return Point(self.r + pt_or_val, self.z + pt_or_val)

  def __sub__(self, pt_or_val: Union["Point", float]):
    if isinstance(pt_or_val, Point):
      return Point(self.r - pt_or_val.r, self.z - pt_or_val.z)
    else:
      return Point(self.r - pt_or_val, self.z - pt_or_val)

  def __mul__(self, pt_or_val: Union["Point", float]):
    if isinstance(pt_or_val, Point):
      return Point(self.r * pt_or_val.r, self.z * pt_or_val.z)
    else:
      return Point(self.r * pt_or_val, self.z * pt_or_val)

  def __truediv__(self, pt_or_val: Union["Point", float]):
    if isinstance(pt_or_val, Point):
      return Point(self.r / pt_or_val.r, self.z / pt_or_val.z)
    else:
      return Point(self.r / pt_or_val, self.z / pt_or_val)

  __div__ = __truediv__


def dist(p1: Union[Point, np.ndarray], p2: Union[Point, np.ndarray]) -> float:
  return np.hypot(*(p1 - p2))


ShapePoints = List[Point]


def to_shape_points(array: np.ndarray) -> ShapePoints:
  return [Point(r, z) for r, z in array]


def center_point(points: ShapePoints) -> Point:
  return sum(points, Point(0, 0)) / len(points)


class ShapeSide(enum.Enum):
  LEFT = 0
  RIGHT = 1
  NOSHIFT = 2


class PolarPoint(NamedTuple):
  angle: float
  dist: float

  def to_point(self) -> Point:
    return Point(np.cos(self.angle) * self.dist, np.sin(self.angle) * self.dist)


def evenly_spaced_angles(num: int):
  return np.arange(num) * 2 * np.pi / num


def angle_aligned_dists(points: np.ndarray, angles: np.ndarray) -> np.ndarray:
  """Return a new set of points along angles that intersect with the shape."""
  # TODO(tewalds): Walk the two arrays together for an O(n+m) algorithm instead
  # of the current O(n*m). This would work as long as they are both sorted
  # around the radial direction, so the next intersection will be near the last.
  return np.array([dist_angle_to_surface(points, a) for a in angles])


def angle_aligned_points(points: np.ndarray, num_points: int,
                         origin: Point) -> np.ndarray:
  """Return points on the shape along evenly spaced angles around origin."""
  angles = evenly_spaced_angles(num_points)
  dists = angle_aligned_dists(points - origin, angles)
  return np.stack((np.cos(angles) * dists,
                   np.sin(angles) * dists), axis=-1) + origin


def dist_angle_to_surface(points: np.ndarray, angle: float) -> float:
  """Distance along a ray to the surface defined by a list of points."""
  for p1, p2 in zip(points, np.roll(points, 1, axis=0)):
    d = dist_angle_to_segment(p1, p2, angle)
    if d is not None:
      return d
  raise ValueError(f"Intersecting edge not found for angle: {angle}")


def dist_angle_to_segment(p1, p2, angle: float) -> Optional[float]:
  """Distance along a ray from the origin to a segment defined by two points."""
  x0, y0 = p1[0], p1[1]
  x1, y1 = p2[0], p2[1]
  a0, b0 = np.cos(angle), np.sin(angle)
  a1, b1 = 0, 0
  # Segment/segment algorithm inspired by https://stackoverflow.com/q/563198
  denom = (b0 - b1) * (x0 - x1) - (y0 - y1) * (a0 - a1)
  if denom == 0:
    return None  # Angle parallel to the segment, so can't intersect.
  xy = (a0 * (y1 - b1) + a1 * (b0 - y1) + x1 * (b1 - b0)) / denom
  eps = 0.00001  # Allow intersecting slightly beyond the endpoints.
  if -eps <= xy <= 1 + eps:  # Check it hit the segment, not just the line.
    ab = (y1 * (x0 - a1) + b1 * (x1 - x0) + y0 * (a1 - x1)) / denom
    if ab > 0:  # Otherwise it hit in the reverse direction.
      # If ab <= 1 then it's within the segment defined above, but given it's
      # a unit vector with one end at the origin this tells us the distance to
      # the intersection of an infinite ray out from the origin.
      return ab
  return None
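
# Sanity-check sketch (not part of the original file): a unit ray at 45
# degrees from the origin hits the vertical segment from (1, 0) to (1, 2) at
# the point (1, 1), i.e. at distance sqrt(2):
#
#   d = dist_angle_to_segment(np.array([1.0, 0.0]), np.array([1.0, 2.0]),
#                             np.pi / 4)
#   assert abs(d - np.sqrt(2)) < 1e-6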


def dist_point_to_surface(points: np.ndarray, point: np.ndarray) -> float:
  """Distance from a point to the surface defined by a list of points."""
  return min(dist_point_to_segment(p1, p2, point)
             for p1, p2 in zip(points, np.roll(points, 1, axis=0)))


def dist_point_to_segment(v: np.ndarray, w: np.ndarray, p: np.ndarray) -> float:
  """Return minimum distance between line segment vw and point p."""
  # Inspired by: https://stackoverflow.com/a/1501725
  l2 = dist(v, w)**2
  if l2 == 0.0:
    return dist(p, v)  # v == w case
  # Consider the line extending the segment, parameterized as v + t (w - v).
  # We find the projection of point p onto the line.
  # It falls where t = [(p-v) . (w-v)] / |w-v|^2
  # We clamp t to [0, 1] to handle points outside the segment vw.
  t = max(0, min(1, np.dot(p - v, w - v) / l2))
  projection = v + t * (w - v)  # Projection falls on the segment
  return dist(p, projection)
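
# Quick check sketch: the point (0, 1) projects onto the segment from (0, 0)
# to (2, 0) at the origin, giving distance 1; points past an endpoint are
# clamped to that endpoint by the t in [0, 1] above.
#
#   assert dist_point_to_segment(np.array([0.0, 0.0]), np.array([2.0, 0.0]),
#                                np.array([0.0, 1.0])) == 1.0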


def sort_by_angle(points: ShapePoints) -> ShapePoints:
  center = sum(points, Point(0, 0)) / len(points)
  return sorted(points, key=lambda p: (p - center).to_polar().angle)


def spline_interpolate_points(
    points: ShapePoints, num_points: int,
    x_points: Optional[ShapePoints] = None) -> ShapePoints:
  """Interpolate along a spline to give a smooth evenly spaced shape."""
  ends = []
  if x_points:
    # Find the shape points that must allow sharp corners.
    for xp in x_points:
      for i, p in enumerate(points):
        if np.hypot(*(p - xp)) < 0.01:
          ends.append(i)

  if not ends:
    # No x-points forcing sharp corners, so use a periodic spline.
    tck, _ = interpolate.splprep(np.array(points + [points[0]]).T, s=0, per=1)
    unew = np.arange(num_points) / num_points
    out = interpolate.splev(unew, tck)
    assert len(out[0]) == num_points
    return sort_by_angle(to_shape_points(np.array(out).T))

  # Generate a spline for each segment, with an x-point at each end.
  new_pts = []
  for i, j in zip(ends, ends[1:] + [ends[0]]):
    pts = points[i:j+1] if i < j else points[i:] + points[:j+1]
    num_segment_points = np.round((len(pts) - 1) / len(points) * num_points)
    unew = np.arange(num_segment_points + 1) / num_segment_points
    tck, _ = interpolate.splprep(np.array(pts).T, s=0)
    out = interpolate.splev(unew, tck)
    new_pts += to_shape_points(np.array(out).T)[:-1]
  if len(new_pts) != num_points:
    raise AssertionError(
        f"Generated the wrong number of points: {len(new_pts)} != {num_points}")
  return sort_by_angle(new_pts)
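
# Usage sketch: smooth a coarse diamond into 32 evenly spaced boundary points
# with a periodic spline (no x-points, so no sharp corners are kept):
#
#   diamond = [Point(1, 0), Point(0, 1), Point(-1, 0), Point(0, -1)]
#   smooth = spline_interpolate_points(diamond, 32)
#   assert len(smooth) == 32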


@dataclasses.dataclass
class ParametrizedShape:
  """Describes a target shape from the parameter set."""
  r0: float  # Where to put the center along the radial axis.
  z0: float  # Where to put the center along the vertical axis.
  kappa: float  # Elongation of the shape. (0.8, 3)
  delta: float  # Triangularity of the shape. (-1, 1)
  radius: float  # Radius of the shape. (0.22, 2.58)
  lambda_: float  # Squareness of the shape. Recommend (0, 0)
  side: ShapeSide  # Whether and which side to shift the shape to.

  @classmethod
  def uniform_random_shape(
      cls,
      r_bounds=(0.8, 0.9),
      z_bounds=(0, 0.2),
      kappa_bounds=(1.0, 1.8),  # elongation
      delta_bounds=(-0.5, 0.6),  # triangularity
      radius_bounds=(tcv_common.LIMITER_WIDTH / 2 - 0.04,
                     tcv_common.LIMITER_WIDTH / 2),
      lambda_bounds=(0, 0),  # squareness
      side=(ShapeSide.LEFT, ShapeSide.RIGHT)):
    """Return a random shape."""
    return cls(
        r0=np.random.uniform(*r_bounds),
        z0=np.random.uniform(*z_bounds),
        kappa=np.random.uniform(*kappa_bounds),
        delta=np.random.uniform(*delta_bounds),
        radius=np.random.uniform(*radius_bounds),
        lambda_=np.random.uniform(*lambda_bounds),
        side=side if isinstance(side, ShapeSide) else random.choice(side))

  def gen_points(self, num_points: int) -> Tuple[ShapePoints, Point]:
    """Generates a set of shape points, returning (points, modified (r0, z0))."""
    r0 = self.r0
    z0 = self.z0
    num_warped_points = 32
    points = np.zeros((num_warped_points, 2))
    theta = evenly_spaced_angles(num_warped_points)
    points[:, 0] = r0 + self.radius * np.cos(theta + self.delta * np.sin(theta)
                                             - self.lambda_ * np.sin(2 * theta))
    points[:, 1] = z0 + self.radius * self.kappa * np.sin(theta)
    if self.side == ShapeSide.LEFT:
      wall_shift = np.min(points[:, 0]) - tcv_common.INNER_LIMITER_R
      points[:, 0] -= wall_shift
      r0 -= wall_shift
    elif self.side == ShapeSide.RIGHT:
      wall_shift = np.max(points[:, 0]) - tcv_common.OUTER_LIMITER_R
      points[:, 0] -= wall_shift
      r0 -= wall_shift
    return (spline_interpolate_points(to_shape_points(points), num_points),
            Point(r0, z0))
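
# The boundary above follows the standard analytic parameterization
#   r(theta) = r0 + a * cos(theta + delta * sin(theta) - lambda * sin(2*theta))
#   z(theta) = z0 + a * kappa * sin(theta)
# with a = radius, so kappa stretches the shape vertically, delta skews it
# toward a D-shape, and lambda adds squareness.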


def trim_zero_points(points: ShapePoints) -> Optional[ShapePoints]:
  trimmed = [p for p in points if p.r != 0]
  return trimmed if trimmed else None


class Diverted(enum.Enum):
  """Whether a shape is diverted or not."""
  ANY = 0
  LIMITED = 1
  DIVERTED = 2

  @classmethod
  def from_refs(cls, references: named_array.NamedArray) -> "Diverted":
    diverted = (references["diverted", 0] == 1)
    limited = (references["limited", 0] == 1)
    if diverted and limited:
      raise ValueError("Diverted and limited doesn't make sense.")
    if diverted:
      return cls.DIVERTED
    if limited:
      return cls.LIMITED
    return cls.ANY


@dataclasses.dataclass
class Shape:
  """Full specification of a shape."""
  params: Optional[ParametrizedShape] = None
  points: Optional[ShapePoints] = None
  x_points: Optional[ShapePoints] = None
  legs: Optional[ShapePoints] = None
  diverted: Diverted = Diverted.ANY
  ip: Optional[float] = None
  limit_point: Optional[Point] = None

  @classmethod
  def from_references(cls, references: named_array.NamedArray) -> "Shape":
    """Extract a Shape from the references."""
    if any(np.any(references[name] != 0)
           for name in ("R", "Z", "kappa", "delta", "radius", "lambda")):
      params = ParametrizedShape(
          r0=references["R"][0],
          z0=references["Z"][0],
          kappa=references["kappa"][0],
          delta=references["delta"][0],
          radius=references["radius"][0],
          lambda_=references["lambda"][0],
          side=ShapeSide.NOSHIFT)
    else:
      params = None

    ip = references["Ip", 0]
    return cls(
        params,
        points=trim_zero_points(points_from_references(references, "shape1")),
        x_points=trim_zero_points(points_from_references(references,
                                                         "x_points")),
        legs=trim_zero_points(points_from_references(references, "legs")),
        limit_point=trim_zero_points(points_from_references(
            references, "limit_point")[0:1]),
        diverted=Diverted.from_refs(references),
        ip=float(ip) if ip != 0 else None)

  def gen_references(self) -> named_array.NamedArray:
    """Return the references for the parametrized shape."""
    refs = tcv_common.REF_RANGES.new_named_array()

    if self.ip is not None:
      refs["Ip", 0] = self.ip

    refs["diverted", 0] = 1 if self.diverted == Diverted.DIVERTED else 0
    refs["limited", 0] = 1 if self.diverted == Diverted.LIMITED else 0

    if self.params is not None:
      refs["R", 0] = self.params.r0
      refs["Z", 0] = self.params.z0
      refs["kappa", 0] = self.params.kappa
      refs["delta", 0] = self.params.delta
      refs["radius", 0] = self.params.radius
      refs["lambda", 0] = self.params.lambda_

    if self.points is not None:
      points = np.array(self.points)
      assert refs.names.count("shape_r") >= points.shape[0]
      refs["shape_r", :points.shape[0]] = points[:, 0]
      refs["shape_z", :points.shape[0]] = points[:, 1]

    if self.x_points is not None:
      x_points = np.array(self.x_points)
      assert refs.names.count("x_points_r") >= x_points.shape[0]
      refs["x_points_r", :x_points.shape[0]] = x_points[:, 0]
      refs["x_points_z", :x_points.shape[0]] = x_points[:, 1]

    if self.legs is not None:
      legs = np.array(self.legs)
      assert refs.names.count("legs_r") >= legs.shape[0]
      refs["legs_r", :legs.shape[0]] = legs[:, 0]
      refs["legs_z", :legs.shape[0]] = legs[:, 1]

    if self.limit_point is not None:
      refs["limit_point_r", 0] = self.limit_point.r
      refs["limit_point_z", 0] = self.limit_point.z

    return refs

  def canonical(self) -> "Shape":
    """Return a canonical shape with a fixed number of points and params."""
    num_points = tcv_common.REF_RANGES.count("shape_r")
    out = copy.deepcopy(self)

    if out.points is None:
      if out.params is None:
        raise ValueError("Can't canonicalize with no params or points.")
      out.points, center = out.params.gen_points(num_points)
      out.params.r0 = center.r
      out.params.z0 = center.z
      out.params.side = ShapeSide.NOSHIFT
    else:
      out.points = spline_interpolate_points(
          out.points, num_points, out.x_points or [])
      if out.params:
        out.params.side = ShapeSide.NOSHIFT
      else:
        # Copied from FGE. Details: https://doi.org/10.1017/S0022377815001270
        top = max(out.points, key=lambda p: p.z)
        left = min(out.points, key=lambda p: p.r)
        right = max(out.points, key=lambda p: p.r)
        bottom = min(out.points, key=lambda p: p.z)
        center = Point((left.r + right.r) / 2,
                       (top.z + bottom.z) / 2)

        radius = (right.r - left.r) / 2
        kappa = (top.z - bottom.z) / (right.r - left.r)
        delta_lower = (center.r - bottom.r) / radius  # lower triangularity
        delta_upper = (center.r - top.r) / radius  # upper triangularity
        delta = (delta_lower + delta_upper) / 2

        out.params = ParametrizedShape(
            r0=center.r, z0=center.z, radius=radius, kappa=kappa, delta=delta,
            lambda_=0, side=ShapeSide.NOSHIFT)

    return out
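
# For intuition: an axis-aligned ellipse of half-width 0.2 and half-height 0.3
# centered at (0.9, 0.2) canonicalizes to radius=0.2, kappa=1.5, delta=0,
# since kappa is height/width and both extreme points sit at the center radius.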


def points_from_references(
    references: named_array.NamedArray, key: str = "shape1",
    num: Optional[int] = None) -> ShapePoints:
  points = np.array([references[f"{key}_r"], references[f"{key}_z"]]).T
  if num is not None:
    points = points[:num]
  return to_shape_points(points)


@dataclasses.dataclass
class ReferenceTimeSlice:
  shape: Shape
  time: float
  hold: Optional[float] = None  # Absolute time.

  def __post_init__(self):
    if self.hold is None:
      self.hold = self.time


def canonicalize_reference_series(
    time_slices: List[ReferenceTimeSlice]) -> List[ReferenceTimeSlice]:
  """Canonicalize a full sequence of time slices."""
  outputs = []
  for ref in time_slices:
    ref_shape = ref.shape.canonical()

    prev = outputs[-1] if outputs else None
    if prev is not None and prev.hold + tcv_common.DT < ref.time:
      leg_diff = len(ref_shape.legs or []) != len(prev.shape.legs or [])
      xp_diff = len(ref_shape.x_points or []) != len(prev.shape.x_points or [])
      div_diff = ref_shape.diverted != prev.shape.diverted
      limit_diff = (
          bool(ref_shape.limit_point and ref_shape.limit_point.r > 0) !=
          bool(prev.shape.limit_point and prev.shape.limit_point.r > 0))
      if leg_diff or xp_diff or div_diff or limit_diff:
        # Try not to interpolate between a real x-point and a non-existent
        # x-point. Non-existent x-points are represented as being at the
        # origin, i.e. out to the left of the vessel, and could be interpolated
        # into place, but that's weird, so better to pop it into existence by
        # adding an extra frame one step before with the new shape targets.
        # This doesn't handle the case of multiple points appearing/disappearing
        # out of order, or of one moving while the other disappears.
        outputs.append(
            ReferenceTimeSlice(
                time=ref.time - tcv_common.DT,
                shape=Shape(
                    ip=ref_shape.ip,
                    params=ref_shape.params,
                    points=ref_shape.points,
                    x_points=(prev.shape.x_points
                              if xp_diff else ref_shape.x_points),
                    legs=(prev.shape.legs if leg_diff else ref_shape.legs),
                    limit_point=(prev.shape.limit_point
                                 if limit_diff else ref_shape.limit_point),
                    diverted=(prev.shape.diverted
                              if div_diff else ref_shape.diverted))))

    outputs.append(ReferenceTimeSlice(ref_shape, ref.time, ref.hold))

  return outputs
68
fusion_tcv/shapes_known.py
Normal file
@@ -0,0 +1,68 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A set of known shapes."""

from fusion_tcv import shape
from fusion_tcv import tcv_common


SHAPE_70166_0450 = shape.Shape(
    ip=-120000,
    params=shape.ParametrizedShape(
        r0=0.89,
        z0=0.25,
        kappa=1.4,
        delta=0.25,
        radius=0.25,
        lambda_=0,
        side=shape.ShapeSide.NOSHIFT),
    limit_point=shape.Point(tcv_common.INNER_LIMITER_R, 0.25),
    diverted=shape.Diverted.LIMITED)


SHAPE_70166_0872 = shape.Shape(
    ip=-110000,
    params=shape.ParametrizedShape(
        r0=0.8796,
        z0=0.2339,
        kappa=1.2441,
        delta=0.2567,
        radius=0.2390,
        lambda_=0,
        side=shape.ShapeSide.NOSHIFT,
    ),
    points=[  # 20 points
        shape.Point(0.6299, 0.1413),
        shape.Point(0.6481, 0.0577),
        shape.Point(0.6804, -0.0087),
        shape.Point(0.7286, -0.0513),
        shape.Point(0.7931, -0.0660),
        shape.Point(0.8709, -0.0513),
        shape.Point(0.9543, -0.0087),
        shape.Point(1.0304, 0.0577),
        shape.Point(1.0844, 0.1413),
        shape.Point(1.1040, 0.2340),
        shape.Point(1.0844, 0.3267),
        shape.Point(1.0304, 0.4103),
        shape.Point(0.9543, 0.4767),
        shape.Point(0.8709, 0.5193),
        shape.Point(0.7931, 0.5340),
        shape.Point(0.7286, 0.5193),
        shape.Point(0.6804, 0.4767),
        shape.Point(0.6481, 0.4103),
        shape.Point(0.6299, 0.3267),
        shape.Point(0.6240, 0.2340),
    ],
    limit_point=shape.Point(tcv_common.INNER_LIMITER_R, 0.2339),
    diverted=shape.Diverted.LIMITED)
622
fusion_tcv/targets.py
Normal file
File diff suppressed because it is too large
273
fusion_tcv/tcv_common.py
Normal file
@@ -0,0 +1,273 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Constants and general tooling for TCV plant."""

import collections
from typing import Sequence, Text
from dm_env import specs
import numpy as np

from fusion_tcv import named_array


DT = 1e-4  # i.e. 10kHz


# Below are general input/output specifications used for controllers that are
# run on hardware. This interface corresponds to the so called "KH hybrid"
# controller specification that is used in various experiments by EPFL. Hence,
# this interface definition contains measurements and actions not used in our
# tasks.

# Number of actions the environment is exposing. Includes dummy (FAST) action.
NUM_ACTIONS = 20

# Number of actuated coils in sim (without the dummy action coil).
NUM_COILS_ACTUATED = 19

# Current and voltage limits by coil type.
# Note these are the limits used in the environments and are different
# from the 'machine engineering limits' (as used/exposed by FGE).
# We apply a safety factor (<= 1.0) to the engineering limits.
CURRENT_SAFETY_FACTOR = 0.8
ENV_COIL_MAX_CURRENTS = collections.OrderedDict(
    E=7500*CURRENT_SAFETY_FACTOR,
    F=7500*CURRENT_SAFETY_FACTOR,
    OH=26000*CURRENT_SAFETY_FACTOR,
    DUMMY=2000*CURRENT_SAFETY_FACTOR,
    G=2000*CURRENT_SAFETY_FACTOR)


# The g-coil has a saturation voltage that is tunable on a shot-by-shot basis.
# There is a deadband, where an action with absolute value of less than 8% of
# the saturation voltage is treated as zero.
ENV_G_COIL_SATURATION_VOLTAGE = 300
ENV_G_COIL_DEADBAND = ENV_G_COIL_SATURATION_VOLTAGE * 0.08

VOLTAGE_SAFETY_FACTOR = 1.0
ENV_COIL_MAX_VOLTAGE = collections.OrderedDict(
    E=1400*VOLTAGE_SAFETY_FACTOR,
    F=2200*VOLTAGE_SAFETY_FACTOR,
    OH=1400*VOLTAGE_SAFETY_FACTOR,
    DUMMY=400*VOLTAGE_SAFETY_FACTOR,
    # This value is also used to clip values for the internal controller,
    # and also to set the deadband voltage.
    G=ENV_G_COIL_SATURATION_VOLTAGE)


# Ordered actions sent by a controller to the TCV.
TCV_ACTIONS = (
    'E_001', 'E_002', 'E_003', 'E_004', 'E_005', 'E_006', 'E_007', 'E_008',
    'F_001', 'F_002', 'F_003', 'F_004', 'F_005', 'F_006', 'F_007', 'F_008',
    'OH_001', 'OH_002',
    'DUMMY_001',  # GAS, ignored by TCV.
    'G_001'  # FAST
)
TCV_ACTION_INDICES = {n: i for i, n in enumerate(TCV_ACTIONS)}

TCV_ACTION_TYPES = collections.OrderedDict(
    E=8,
    F=8,
    OH=2,
    DUMMY=1,
    G=1,
)

# Map the TCV actions to ranges of indices in the array.
TCV_ACTION_RANGES = named_array.NamedRanges(TCV_ACTION_TYPES)

# The voltages seem not to be centered at 0, but instead near these values:
TCV_ACTION_OFFSETS = {
    'E_001': 6.79,
    'E_002': -10.40,
    'E_003': -1.45,
    'E_004': 0.18,
    'E_005': 11.36,
    'E_006': -0.95,
    'E_007': -4.28,
    'E_008': 44.22,
    'F_001': 38.49,
    'F_002': -2.94,
    'F_003': 5.58,
    'F_004': 1.09,
    'F_005': -36.63,
    'F_006': -9.18,
    'F_007': 5.34,
    'F_008': 10.53,
    'OH_001': -53.63,
    'OH_002': -14.76,
}

TCV_ACTION_DELAYS = {
    'E': [0.0005] * 8,
    'F': [0.0005] * 8,
    'OH': [0.0005] * 2,
    'G': [0.0001],
}


# Ordered measurements and their dimensions from the TCV controller specs.
TCV_MEASUREMENTS = collections.OrderedDict(
    clint_vloop=1,  # Flux loop 1
    clint_rvloop=37,  # Difference of flux between loops 2-38 and flux loop 1
    bm=38,  # Magnetic field probes
    IE=8,  # E-coil currents
    IF=8,  # F-coil currents
    IOH=2,  # OH-coil currents
    Bdot=20,  # Selection of 20 time-derivatives of magnetic field probes (bm).
    DIOH=1,  # OH-coil currents difference: OH(0) - OH(1).
    FIR_FRINGE=1,  # Not used, ignore.
    IG=1,  # G-coil current
    ONEMM=1,  # Not used, ignore.
    vloop=1,  # Flux loop 1 derivative
    IPHI=1,  # Current through the Toroidal Field coils. Constant. Ignore.
)

NUM_MEASUREMENTS = sum(TCV_MEASUREMENTS.values())
# Map the TCV measurements to ranges of indices in the array.
TCV_MEASUREMENT_RANGES = named_array.NamedRanges(TCV_MEASUREMENTS)

# Several of the measurement probes for the rvloops are broken. Add an extra
# key that allows us to only grab the usable ones.
BROKEN_RVLOOP_IDXS = [9, 10, 11]

TCV_MEASUREMENT_RANGES.set_range('clint_rvloop_usable', [
    idx for i, idx in enumerate(TCV_MEASUREMENT_RANGES['clint_rvloop'])
    if i not in BROKEN_RVLOOP_IDXS])

TCV_COIL_CURRENTS_INDEX = [
    *TCV_MEASUREMENT_RANGES['IE'],
    *TCV_MEASUREMENT_RANGES['IF'],
    *TCV_MEASUREMENT_RANGES['IOH'],
    *TCV_MEASUREMENT_RANGES['IPHI'],  # In place of DUMMY.
    *TCV_MEASUREMENT_RANGES['IG'],
]


# References for what we want the agent to accomplish.
REF_RANGES = named_array.NamedRanges({
    'R': 2,
    'Z': 2,
    'Ip': 2,
    'kappa': 2,
    'delta': 2,
    'radius': 2,
    'lambda': 2,
    'diverted': 2,  # bool, must be diverted
    'limited': 2,  # bool, must be limited
    'shape_r': 32,
    'shape_z': 32,
    'x_points_r': 8,
    'x_points_z': 8,
    'legs_r': 16,  # Use for diverted/snowflake
    'legs_z': 16,
    'limit_point_r': 2,
    'limit_point_z': 2,
})

# Environments should use a consistent datatype for interacting with agents.
ENVIRONMENT_DATA_TYPE = np.float64


def observation_spec():
  """Observation spec for all TCV environments."""
  return {
      'references':
          specs.Array(
              shape=(REF_RANGES.size,),
              dtype=ENVIRONMENT_DATA_TYPE,
              name='references'),
      'measurements':
          specs.Array(
              shape=(TCV_MEASUREMENT_RANGES.size,),
              dtype=ENVIRONMENT_DATA_TYPE,
              name='measurements'),
      'last_action':
          specs.Array(
              shape=(TCV_ACTION_RANGES.size,),
              dtype=ENVIRONMENT_DATA_TYPE,
              name='last_action'),
  }


def measurements_to_dict(measurements):
  """Converts a single measurement vector or a time series to a dict.

  Args:
    measurements: A single measurement of size `NUM_MEASUREMENTS` or a time
      series, where the batch dimension is last, shape: (NUM_MEASUREMENTS, t).

  Returns:
    A dict mapping keys `TCV_MEASUREMENTS` to the corresponding measurements.
  """
  assert measurements.shape[0] == NUM_MEASUREMENTS
  measurements_dict = collections.OrderedDict()
  index = 0
  for key, dim in TCV_MEASUREMENTS.items():
    measurements_dict[key] = measurements[index:index + dim, ...]
    index += dim
  return measurements_dict


def dict_to_measurement(measurement_dict):
  """Converts a single measurement dict to a vector or time series.

  Args:
    measurement_dict: A dict with the measurement keys containing np arrays of
      size (meas_size, ...). The inner sizes all have to be the same.

  Returns:
    An array of size (num_measurements, ...).
  """
  assert len(measurement_dict) == len(TCV_MEASUREMENTS)
  # Grab the shape of the first array.
  shape = measurement_dict['clint_vloop'].shape

  out_shape = list(shape)
  out_shape[0] = NUM_MEASUREMENTS
  out_shape = tuple(out_shape)
  measurements = np.zeros(out_shape)
  index = 0
  for key, dim in TCV_MEASUREMENTS.items():
    measurements[index:index + dim, ...] = measurement_dict[key]
    index += dim
  return measurements
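
# Round-trip sketch: the two conversions above are inverses for well-formed
# inputs.
#
#   vec = np.arange(NUM_MEASUREMENTS, dtype=float)
#   assert np.array_equal(dict_to_measurement(measurements_to_dict(vec)), vec)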


def action_spec():
  return get_coil_spec(TCV_ACTIONS, ENV_COIL_MAX_VOLTAGE, ENVIRONMENT_DATA_TYPE)


def get_coil_spec(coil_names: Sequence[Text],
                  spec_mapping,
                  dtype=ENVIRONMENT_DATA_TYPE) -> specs.BoundedArray:
  """Maps specs indexed by coil type to coils given their type."""
  coil_max, coil_min = [], []
  for name in coil_names:
    # Coil names are <coil_type>_<coil_number>
    coil_type, _ = name.split('_')
    coil_max.append(spec_mapping[coil_type])
    coil_min.append(-spec_mapping[coil_type])
  return specs.BoundedArray(
      shape=(len(coil_names),), dtype=dtype, minimum=coil_min, maximum=coil_max)
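
# e.g. action_spec() is a 20-dim BoundedArray whose per-coil bounds follow the
# coil type above: E/OH coils +/-1400 V, F coils +/-2200 V, G coil +/-300 V.
#
#   spec = action_spec()
#   assert spec.shape == (NUM_ACTIONS,)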


INNER_LIMITER_R = 0.62400001
OUTER_LIMITER_R = 1.14179182
LIMITER_WIDTH = OUTER_LIMITER_R - INNER_LIMITER_R
LIMITER_RADIUS = LIMITER_WIDTH / 2
VESSEL_CENTER_R = INNER_LIMITER_R + LIMITER_RADIUS
102
fusion_tcv/terminations.py
Normal file
@@ -0,0 +1,102 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Terminations for the fusion environment."""

import abc
from typing import List, Optional

import numpy as np

from fusion_tcv import fge_state
from fusion_tcv import tcv_common


class Abstract(abc.ABC):
  """Abstract termination class."""

  @abc.abstractmethod
  def terminate(self, state: fge_state.FGEState) -> Optional[str]:
    """Returns a reason if the situation should be considered a termination."""


class CoilCurrentSaturation(Abstract):
  """Terminates if the coils have saturated their current."""

  def terminate(self, state: fge_state.FGEState) -> Optional[str]:
    # Coil currents are checked by type, independent of the order.
    for coil_type, max_current in tcv_common.ENV_COIL_MAX_CURRENTS.items():
      if coil_type == "DUMMY":
        continue
      currents = state.get_coil_currents_by_type(coil_type)
      if (np.abs(currents) > max_current).any():
        return (f"CoilCurrentSaturation: {coil_type}, max: {max_current}, "
                "real: " + ", ".join(f"{c:.1f}" for c in currents))
    return None


class OHTooDifferent(Abstract):
  """Terminates if the OH coil currents are too far apart from one another."""

  def __init__(self, max_diff: float):
    self._max_diff = max_diff

  def terminate(self, state: fge_state.FGEState) -> Optional[str]:
    oh_coil_currents = state.get_coil_currents_by_type("OH")
    assert len(oh_coil_currents) == 2
    oh_current_abs = abs(oh_coil_currents[0] - oh_coil_currents[1])
    if oh_current_abs > self._max_diff:
      return ("OHTooDifferent: currents: "
              f"({oh_coil_currents[0]:.0f}, {oh_coil_currents[1]:.0f}), "
              f"diff: {oh_current_abs:.0f}, max: {self._max_diff}")
    return None


class IPTooLow(Abstract):
  """Terminates if the magnitude of Ip in any component is too low."""

  def __init__(self, singlet_threshold: float, droplet_threshold: float):
    self._singlet_threshold = singlet_threshold
    self._droplet_threshold = droplet_threshold

  def terminate(self, state: fge_state.FGEState) -> Optional[str]:
    _, _, ip_d = state.rzip_d
    if len(ip_d) == 1:
      if ip_d[0] > self._singlet_threshold:  # Sign due to negative Ip.
        return f"IPTooLow: Singlet, {ip_d[0]:.0f}"
      return None
    else:
      if max(ip_d) > self._droplet_threshold:  # Sign due to negative Ip.
        return f"IPTooLow: Components: {ip_d[0]:.0f}, {ip_d[1]:.0f}"
      return None


class AnyTermination(Abstract):
  """Terminates if any of the conditions are met."""

  def __init__(self, terminators: List[Abstract]):
    self._terminators = terminators

  def terminate(self, state: fge_state.FGEState) -> Optional[str]:
    for terminator in self._terminators:
      term = terminator.terminate(state)
      if term:
        return term
    return None


CURRENT_OH_IP = AnyTermination([
    CoilCurrentSaturation(),
    OHTooDifferent(max_diff=4000),
    IPTooLow(singlet_threshold=-60000, droplet_threshold=-25000),
])
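
# Usage sketch inside an environment step (`state` is the current FGEState):
#
#   reason = CURRENT_OH_IP.terminate(state)
#   if reason is not None:
#     ...  # End the episode and log `reason`.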
40
fusion_tcv/trajectory.py
Normal file
@@ -0,0 +1,40 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A trajectory for an episode."""

from typing import List

import dataclasses
import numpy as np


@dataclasses.dataclass
class Trajectory:
  """A trajectory of actions/obs for an episode."""
  measurements: np.ndarray
  references: np.ndarray
  reward: np.ndarray
  actions: np.ndarray

  @classmethod
  def stack(cls, series: List["Trajectory"]) -> "Trajectory":
    """Stack a series of trajectories, adding a trailing time dimension."""
    values = {k: np.empty(v.shape + (len(series),))
              for k, v in dataclasses.asdict(series[0]).items()
              if v is not None}
    for i, ts in enumerate(series):
      for k, v in values.items():
        v[..., i] = getattr(ts, k)
    return cls(**values)
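
# Stacking sketch: three single-step trajectories whose fields have shape (2,)
# become one trajectory whose arrays have a trailing time axis of length 3.
#
#   step = Trajectory(measurements=np.zeros(2), references=np.zeros(2),
#                     reward=np.zeros(1), actions=np.zeros(2))
#   stacked = Trajectory.stack([step, step, step])
#   assert stacked.measurements.shape == (2, 3)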
183
fusion_tcv/transforms.py
Normal file
@@ -0,0 +1,183 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Transforms from actual/target to rewards."""

import abc
import math
from typing import List, Optional

import dataclasses


# Comparison of some of the transforms:
# Order is NegExp, SoftPlus, Sigmoid.
# Over range of good to bad:
# https://www.wolframalpha.com/input/?i=plot+e%5E%28x*ln%280.1%29%29%2C2%2F%281%2Be%5E%28x*ln%2819%29%29%29%2C1%2F%281%2Be%5E%28-+%28-ln%2819%29+-+%28x-1%29*%282*ln%2819%29%29%29%29%29+from+x%3D0+to+2
# When close to good:
# https://www.wolframalpha.com/input/?i=plot+e%5E%28x*ln%280.1%29%29%2C2%2F%281%2Be%5E%28x*ln%2819%29%29%29%2C1%2F%281%2Be%5E%28-+%28-ln%2819%29+-+%28x-1%29*%282*ln%2819%29%29%29%29%29+from+x%3D0+to+0.2


class AbstractTransform(abc.ABC):

  @abc.abstractmethod
  def __call__(self, errors: List[float]) -> List[float]:
    """Transforms target errors into rewards."""

  @property
  def outputs(self) -> Optional[int]:
    return None


def clip(value: float, low: float, high: float) -> float:
  """Clip a value to the range of low - high."""
  if math.isnan(value):
    return value
  assert low <= high
  return max(low, min(high, value))


def scale(v: float, a: float, b: float, c: float, d: float) -> float:
  """Scale a value, v on a line with anchor points a,b to new anchors c,d."""
  v01 = (v - a) / (b - a)
  return c - v01 * (c - d)


def logistic(v: float) -> float:
  """Standard logistic, asymptoting to 0 and 1."""
  v = clip(v, -50, 50)  # Improve numerical stability.
  return 1 / (1 + math.exp(-v))


@dataclasses.dataclass(frozen=True)
class Equal(AbstractTransform):
  """Returns 1 if the error is 0 and not_equal_val otherwise."""
  not_equal_val: float = 0

  def __call__(self, errors: List[float]) -> List[float]:
    out = []
    for err in errors:
      if math.isnan(err):
        out.append(err)
      elif err == 0:
        out.append(1)
      else:
        out.append(self.not_equal_val)
    return out


class Abs(AbstractTransform):
  """Take the absolute value of the error. Does not guarantee 0-1."""

  @staticmethod
  def __call__(errors: List[float]) -> List[float]:
    return [abs(err) for err in errors]


class Neg(AbstractTransform):
  """Negate the error. Does not guarantee 0-1."""

  @staticmethod
  def __call__(errors: List[float]) -> List[float]:
    return [-err for err in errors]


@dataclasses.dataclass(frozen=True)
class Pow(AbstractTransform):
  """Return a power of the error. Does not guarantee 0-1."""
  pow: float

  def __call__(self, errors: List[float]) -> List[float]:
    return [err**self.pow for err in errors]


@dataclasses.dataclass(frozen=True)
class Log(AbstractTransform):
  """Return a log of the error. Does not guarantee 0-1."""
  eps: float = 1e-4

  def __call__(self, errors: List[float]) -> List[float]:
    return [math.log(err + self.eps) for err in errors]


@dataclasses.dataclass(frozen=True)
class ClippedLinear(AbstractTransform):
  """Scales and clips errors, bad to 0, good to 1. If good=0, this is a relu."""
  bad: float
  good: float = 0

  def __call__(self, errors: List[float]) -> List[float]:
    return [clip(scale(err, self.bad, self.good, 0, 1), 0, 1)
            for err in errors]


@dataclasses.dataclass(frozen=True)
class SoftPlus(AbstractTransform):
  """Scales and clips errors, bad to 0.1, good to 1, asymptoting to 0.

  Based on the lower half of the logistic instead of the standard softplus as
  we want it to be bounded from 0 to 1, with the good value being exactly 1.
  Various constants can be chosen to get the softplus to give the desired
  properties, but this is much simpler.
  """
  bad: float
  good: float = 0

  # Constant to set the sharpness/slope of the softplus.
  # Default was chosen such that the good/bad have 1 and 0.1 reward:
  # https://www.wolframalpha.com/input/?i=plot+2%2F%281%2Be%5E%28x*ln%2819%29%29%29+from+x%3D0+to+2
  low: float = -math.log(19)  # -2.9444389791664403

  def __call__(self, errors: List[float]) -> List[float]:
    return [clip(2 * logistic(scale(e, self.bad, self.good, self.low, 0)), 0, 1)
            for e in errors]
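
# e.g. SoftPlus(good=0, bad=1) maps an error of 0 to reward 1, an error of 1
# to 0.1, and asymptotes to 0 for larger errors:
#
#   sp = SoftPlus(good=0, bad=1)
#   assert abs(sp([0.0])[0] - 1.0) < 1e-6
#   assert abs(sp([1.0])[0] - 0.1) < 1e-6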


@dataclasses.dataclass(frozen=True)
class NegExp(AbstractTransform):
  """Scales and clips errors, bad to 0.1, good to 1, asymptoting to 0.

  This scales the reward in an exponential space. This means there is a sharp
  gradient toward reaching the value of good, flattening out at the value of
  bad. This can be useful for a reward that gives meaningful signal far away,
  but still has a sharp gradient near the true target.
  """
  bad: float
  good: float = 0

  # Constant to set the sharpness/slope of the exponential.
  # Default was chosen such that the good/bad have 1 and 0.1 reward:
  # https://www.wolframalpha.com/input/?i=plot+e%5E%28x*ln%280.1%29%29+from+x%3D0+to+2
  low: float = -math.log(0.1)

  def __call__(self, errors: List[float]) -> List[float]:
    return [clip(math.exp(-scale(e, self.bad, self.good, self.low, 0)), 0, 1)
            for e in errors]


@dataclasses.dataclass(frozen=True)
class Sigmoid(AbstractTransform):
  """Scales and clips errors, bad to 0.05, good to 0.95, asymptoting to 0-1."""
  good: float
  bad: float

  # Constants to set the sharpness/slope of the sigmoid.
  # Defaults were chosen such that the good/bad have 0.95 and 0.05 reward:
  # https://www.wolframalpha.com/input/?i=plot+1%2F%281%2Be%5E%28-+%28-ln%2819%29+-+%28x-1%29*%282*ln%2819%29%29%29%29%29+from+x%3D0+to+2
  high: float = math.log(19)  # +2.9444389791664403
  low: float = -math.log(19)  # -2.9444389791664403

  def __call__(self, errors: List[float]) -> List[float]:
    return [logistic(scale(err, self.bad, self.good, self.low, self.high))
            for err in errors]
162
fusion_tcv/transforms_test.py
Normal file
@@ -0,0 +1,162 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transforms."""

import math

from absl.testing import absltest
from fusion_tcv import transforms


NAN = float("nan")


class TransformsTest(absltest.TestCase):

  def assertNan(self, value: float):
    self.assertTrue(math.isnan(value))

  def test_clip(self):
    self.assertEqual(transforms.clip(-1, 0, 1), 0)
    self.assertEqual(transforms.clip(5, 0, 1), 1)
    self.assertEqual(transforms.clip(0.5, 0, 1), 0.5)
    self.assertNan(transforms.clip(NAN, 0, 1))

  def test_scale(self):
    self.assertEqual(transforms.scale(0, 0, 0.5, 0, 1), 0)
    self.assertEqual(transforms.scale(0.125, 0, 0.5, 0, 1), 0.25)
    self.assertEqual(transforms.scale(0.25, 0, 0.5, 0, 1), 0.5)
    self.assertEqual(transforms.scale(0.5, 0, 0.5, 0, 1), 1)
    self.assertEqual(transforms.scale(1, 0, 0.5, 0, 1), 2)
    self.assertEqual(transforms.scale(-1, 0, 0.5, 0, 1), -2)
    self.assertEqual(transforms.scale(0.5, 1, 0, 0, 1), 0.5)
    self.assertEqual(transforms.scale(0.25, 1, 0, 0, 1), 0.75)

    self.assertEqual(transforms.scale(0, 0, 1, -4, 4), -4)
    self.assertEqual(transforms.scale(0.25, 0, 1, -4, 4), -2)
    self.assertEqual(transforms.scale(0.5, 0, 1, -4, 4), 0)
    self.assertEqual(transforms.scale(0.75, 0, 1, -4, 4), 2)
    self.assertEqual(transforms.scale(1, 0, 1, -4, 4), 4)

    self.assertNan(transforms.scale(NAN, 0, 1, -4, 4))

  def test_logistic(self):
    self.assertLess(transforms.logistic(-50), 0.000001)
    self.assertLess(transforms.logistic(-5), 0.01)
    self.assertEqual(transforms.logistic(0), 0.5)
    self.assertGreater(transforms.logistic(5), 0.99)
    self.assertGreater(transforms.logistic(50), 0.999999)
    self.assertAlmostEqual(transforms.logistic(0.8), math.tanh(0.4) / 2 + 0.5)
    self.assertNan(transforms.logistic(NAN))

  def test_exp_scaled(self):
    t = transforms.NegExp(good=0, bad=1)
    self.assertNan(t([NAN])[0])
    self.assertAlmostEqual(t([0])[0], 1)
    self.assertAlmostEqual(t([1])[0], 0.1)
    self.assertLess(t([50])[0], 0.000001)

    t = transforms.NegExp(good=10, bad=30)
    self.assertAlmostEqual(t([0])[0], 1)
    self.assertAlmostEqual(t([10])[0], 1)
    self.assertLess(t([3000])[0], 0.000001)

    t = transforms.NegExp(good=30, bad=10)
    self.assertAlmostEqual(t([50])[0], 1)
    self.assertAlmostEqual(t([30])[0], 1)
    self.assertAlmostEqual(t([10])[0], 0.1)
    self.assertLess(t([-90])[0], 0.00001)

  def test_neg(self):
    t = transforms.Neg()
    self.assertEqual(t([-5, -3, 0, 1, 4]), [5, 3, 0, -1, -4])
    self.assertNan(t([NAN])[0])

  def test_abs(self):
    t = transforms.Abs()
    self.assertEqual(t([-5, -3, 0, 1, 4]), [5, 3, 0, 1, 4])
    self.assertNan(t([NAN])[0])

  def test_pow(self):
    t = transforms.Pow(2)
    self.assertEqual(t([-5, -3, 0, 1, 4]), [25, 9, 0, 1, 16])
    self.assertNan(t([NAN])[0])

  def test_log(self):
    t = transforms.Log()
    self.assertAlmostEqual(t([math.exp(2)])[0], 2, 4)  # Low precision from eps.
    self.assertNan(t([NAN])[0])

  def test_clipped_linear(self):
    t = transforms.ClippedLinear(good=0.1, bad=0.3)
    self.assertAlmostEqual(t([0])[0], 1)
    self.assertAlmostEqual(t([0.05])[0], 1)
    self.assertAlmostEqual(t([0.1])[0], 1)
    self.assertAlmostEqual(t([0.15])[0], 0.75)
    self.assertAlmostEqual(t([0.2])[0], 0.5)
    self.assertAlmostEqual(t([0.25])[0], 0.25)
    self.assertAlmostEqual(t([0.3])[0], 0)
    self.assertAlmostEqual(t([0.4])[0], 0)
    self.assertNan(t([NAN])[0])

    t = transforms.ClippedLinear(good=1, bad=0.5)
    self.assertAlmostEqual(t([1.5])[0], 1)
    self.assertAlmostEqual(t([1])[0], 1)
    self.assertAlmostEqual(t([0.75])[0], 0.5)
    self.assertAlmostEqual(t([0.5])[0], 0)
    self.assertAlmostEqual(t([0.25])[0], 0)

  def test_softplus(self):
    t = transforms.SoftPlus(good=0.1, bad=0.3)
    self.assertEqual(t([0])[0], 1)
    self.assertEqual(t([0.1])[0], 1)
    self.assertAlmostEqual(t([0.3])[0], 0.1)
    self.assertLess(t([0.5])[0], 0.01)
    self.assertNan(t([NAN])[0])

    t = transforms.SoftPlus(good=1, bad=0.5)
    self.assertEqual(t([1.5])[0], 1)
    self.assertEqual(t([1])[0], 1)
    self.assertAlmostEqual(t([0.5])[0], 0.1)
    self.assertLess(t([0.1])[0], 0.01)

  def test_sigmoid(self):
    t = transforms.Sigmoid(good=0.1, bad=0.3)
    self.assertGreater(t([0])[0], 0.99)
    self.assertAlmostEqual(t([0.1])[0], 0.95)
    self.assertAlmostEqual(t([0.2])[0], 0.5)
    self.assertAlmostEqual(t([0.3])[0], 0.05)
    self.assertLess(t([0.4])[0], 0.01)
    self.assertNan(t([NAN])[0])

    t = transforms.Sigmoid(good=1, bad=0.5)
    self.assertGreater(t([1.5])[0], 0.99)
    self.assertAlmostEqual(t([1])[0], 0.95)
    self.assertAlmostEqual(t([0.75])[0], 0.5)
    self.assertAlmostEqual(t([0.5])[0], 0.05)
    self.assertLess(t([0.25])[0], 0.01)

  def test_equal(self):
    t = transforms.Equal()
    self.assertEqual(t([0])[0], 1)
    self.assertEqual(t([0.001])[0], 0)
    self.assertNan(t([NAN])[0])

    t = transforms.Equal(not_equal_val=0.5)
    self.assertEqual(t([0])[0], 1)
    self.assertEqual(t([0.001])[0], 0.5)


if __name__ == "__main__":
  absltest.main()
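To exercise the transforms end to end, the test file runs as a standard
absltest module; assuming the fusion_tcv package layout shown above, something
like:

```shell
python -m fusion_tcv.transforms_test
```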