Internal change

PiperOrigin-RevId: 417034217
Nimrod Gileadi
2021-12-17 18:27:35 +00:00
committed by Diego de Las Casas
parent fba48d1e44
commit 3f0d5ed1a0
41 changed files with 6676 additions and 0 deletions

59
avae/README.md Normal file

@@ -0,0 +1,59 @@
# The Autoencoding Variational Autoencoder
This is the code for the models in the NeurIPS 2020 paper [AVAE](https://papers.nips.cc/paper/2020/file/ac10ff1941c540cd87c107330996f4f6-Paper.pdf).
This folder contains the code to train the AVAE model in JAX; the evaluation
setup will be uploaded soon.
Code files in this folder:
- checkpointer.py: Checkpointing abstraction
- data_iterators.py: Datasets to be used
- decoders.py: VAE decoder network architectures
- encoders.py: VAE encoder network architectures
- kl.py: KL divergence computation between two Gaussians
- train.py: Function to train given ELBO, network and data
- train_main.py: Main file to train AVAE
- vae.py: VAE model defining various ELBOs
## Setup
To set up a Python3 virtual environment with the required dependencies, run:
```shell
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt
```
## Running AVAE training
The following command runs AVAE training on the ColorMnist dataset using MLP
network architectures.
```shell
python -m avae.train_main \
--dataset='color_mnist' \
--latent_dim=64 \
--checkpoint_dir='/tmp/avae_checkpoints' \
--checkpoint_filename='color_mnist_mlp_avae' \
--rho=0.975 \
--encoder='color_mnist_mlp_encoder' \
--decoder='color_mnist_mlp_decoder'
```
## References
### Citing our work
If you use this code for your research, please consider citing our paper:
```bibtex
@article{cemgil2020autoencoding,
title={The Autoencoding Variational Autoencoder},
author={Cemgil, Taylan and Ghaisas, Sumedh and Dvijotham, Krishnamurthy and Gowal, Sven and Kohli, Pushmeet},
journal={Advances in Neural Information Processing Systems},
volume={33},
year={2020}
}
```

BIN
avae/avae_image.png Normal file

Binary file not shown.

88
avae/checkpointer.py Normal file

@@ -0,0 +1,88 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpointing functionality."""
import os
from typing import Any, Mapping, Optional
from absl import logging
import dill
import jax
import jax.numpy as jnp
class Checkpointer:
"""A checkpoint saving and loading class."""
def __init__(self, checkpoint_dir: str, filename: str):
"""Class initializer.
Args:
checkpoint_dir: Checkpoint directory.
filename: Filename of checkpoint in checkpoint directory.
"""
self._checkpoint_dir = checkpoint_dir
if not os.path.isdir(self._checkpoint_dir):
os.mkdir(self._checkpoint_dir)
self._checkpoint_path = os.path.join(self._checkpoint_dir, filename)
def save_checkpoint(
self,
experiment_state: Mapping[str, jnp.ndarray],
opt_state: Mapping[str, jnp.ndarray],
step: int,
extra_checkpoint_info: Optional[Mapping[str, Any]] = None) -> None:
"""Save checkpoint with experiment state and step information.
Args:
experiment_state: Experiment params to be stored.
opt_state: Optimizer state to be stored.
step: Training iteration step.
extra_checkpoint_info: Extra information to be stored.
"""
if jax.host_id() != 0:
return
checkpoint_data = dict(
experiment_state=jax.tree_map(jax.device_get, experiment_state),
opt_state=jax.tree_map(jax.device_get, opt_state),
step=step)
if extra_checkpoint_info is not None:
for key in extra_checkpoint_info:
checkpoint_data[key] = extra_checkpoint_info[key]
with open(self._checkpoint_path, 'wb') as checkpoint_file:
dill.dump(checkpoint_data, checkpoint_file, protocol=2)
def load_checkpoint(
self) -> Optional[Mapping[str, Mapping[str, jnp.ndarray]]]:
"""Load and return checkpoint data.
Returns:
Loaded checkpoint if it exists else returns None.
"""
if os.path.exists(self._checkpoint_path):
with open(self._checkpoint_path, 'rb') as checkpoint_file:
checkpoint_data = dill.load(checkpoint_file)
logging.info('Loading checkpoint from %s, saved at step %d',
self._checkpoint_path, checkpoint_data['step'])
return checkpoint_data
else:
logging.warning('No pre-saved checkpoint found at %s',
self._checkpoint_path)
return None
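A minimal usage sketch of the `Checkpointer` above (the directory and filename are hypothetical; in real training the state comes from `hk.Transformed.init` and an optax optimizer, with plain dicts standing in here):

```python
import jax.numpy as jnp
from avae import checkpointer

# Hypothetical checkpoint location.
ckpt = checkpointer.Checkpointer('/tmp/avae_checkpoints', 'demo_ckpt')

# Stand-in params and optimizer state; any pytree of arrays works.
ckpt.save_checkpoint(
    experiment_state={'w': jnp.zeros(3)},
    opt_state={'m': jnp.zeros(3)},
    step=0)

restored = ckpt.load_checkpoint()
if restored is not None:
    print(restored['step'])  # -> 0
```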

97
avae/data_iterators.py Normal file

@@ -0,0 +1,97 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset iterators Mnist, ColorMnist."""
import enum
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from avae import types
class Dataset(enum.Enum):
color_mnist = enum.auto()
class MnistDataIterator(object):
"""Mnist data iterator class.
Data is obtained as dataclass, types.LabelledData.
"""
def __init__(self, subset: str, batch_size: int):
"""Class initializer.
Args:
subset: Subset of dataset.
batch_size: Batch size of the returned dataset iterator.
"""
datasets = tfds.load('mnist')
train_dataset = datasets[subset]
def _map_fun(x):
return {'data': tf.cast(x['image'], tf.float32) / 255.0,
'label': x['label']}
train_dataset = train_dataset.map(_map_fun)
train_dataset = train_dataset.batch(batch_size,
drop_remainder=True).repeat()
self._iterator = iter(tfds.as_numpy(train_dataset))
self._batch_size = batch_size
def __iter__(self):
return self
def __next__(self) -> types.LabelledData:
return types.LabelledData(**next(self._iterator))
class ColorMnistDataIterator(MnistDataIterator):
"""Color Mnist data iterator.
Each ColorMnist image is of shape (28, 28, 3). A ColorMnist digit can take 7
different colors, formed by turning the R, G and B channels on or off
(excluding the all-channels-off combination).
Data is obtained as the dataclass types.LabelledData.
"""
def __next__(self) -> types.LabelledData:
mnist_batch = next(self._iterator)
mnist_image = mnist_batch['data']
# Colors are supported by turning off and on RGB channels.
# Thus possible colors are
# [black (excluded), red, green, yellow, blue, magenta, cyan, white]
# color_id takes value from [1,8)
color_id = np.random.randint(7, size=self._batch_size) + 1
red_channel_bool = np.mod(color_id, 2)
red_channel_bool = jnp.reshape(red_channel_bool, [-1, 1, 1, 1])
blue_channel_bool = np.floor_divide(color_id, 4)
blue_channel_bool = jnp.reshape(blue_channel_bool, [-1, 1, 1, 1])
green_channel_bool = np.mod(np.floor_divide(color_id, 2), 2)
green_channel_bool = jnp.reshape(green_channel_bool, [-1, 1, 1, 1])
color_mnist_image = jnp.stack([
jnp.multiply(red_channel_bool, mnist_image),
jnp.multiply(blue_channel_bool, mnist_image),
jnp.multiply(green_channel_bool, mnist_image)], axis=3)
color_mnist_image = jnp.reshape(color_mnist_image, [-1, 28, 28, 3])
# color_id takes values in [1, 8); a corresponding `color` label would take
# values in [0, 7) (by subtracting 1 from color_id) to simplify downstream
# classification code.
return types.LabelledData(
data=color_mnist_image, label=mnist_batch['label'])
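A quick sketch of consuming these iterators (hypothetical batch size; tfds downloads MNIST on first use):

```python
from avae import data_iterators

it = data_iterators.ColorMnistDataIterator('train', batch_size=8)
batch = next(it)
# Colorized digits and their digit-class labels.
print(batch.data.shape)   # (8, 28, 28, 3)
print(batch.label.shape)  # (8,)
```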

88
avae/decoders.py Normal file

@@ -0,0 +1,88 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decoder architectures to be used with VAE."""
import abc
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
class DecoderBase(hk.Module):
"""Base class for decoder network classes."""
def __init__(self, obs_var: float):
"""Class initializer.
Args:
obs_var: Observation variance of the dataset.
"""
super().__init__()
self._obs_var = obs_var
@abc.abstractmethod
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
"""Reconstruct from a given latent sample.
Args:
z: latent samples of shape (batch_size, latent_dim)
Returns:
Reconstruction with shape (batch_size, ...).
"""
def data_fidelity(
self,
input_data: jnp.ndarray,
recons: jnp.ndarray,
) -> jnp.ndarray:
"""Compute Data fidelity (recons loss) for given input and recons.
Args:
input_data: Input batch of shape (batch_size, ...).
recons: Reconstruction of the input data. An array with the same shape as
`input_data.data`.
Returns:
Computed data fidelity term across batch of data. An array of shape
`(batch_size,)`.
"""
error = (input_data - recons).reshape(input_data.shape[0], -1)
return -0.5 * jnp.sum(jnp.square(error), axis=1) / self._obs_var
class ColorMnistMLPDecoder(DecoderBase):
"""MLP decoder for Color Mnist."""
_hidden_units = (200, 200, 200, 200)
_image_dims = (28, 28, 3) # Dimensions of a single MNIST image.
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
"""Reconstruct with given latent sample.
Args:
z: latent samples of shape (batch_size, latent_dim)
Returns:
Reconstructed data of shape (batch_size, 28, 28, 3).
"""
out = z
for units in self._hidden_units:
out = hk.Linear(units)(out)
out = jax.nn.relu(out)
out = hk.Linear(np.prod(self._image_dims))(out)
out = jax.nn.sigmoid(out)
return jnp.reshape(out, (-1,) + self._image_dims)
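For reference, `data_fidelity` above is the Gaussian log-likelihood of the input given the reconstruction, up to an additive normalization constant that the code omits; with observation variance $\sigma^2$ (`obs_var`), per example:

```latex
\log p(x \mid z) = -\frac{\lVert x - \hat{x}(z) \rVert^2}{2\sigma^2} + \text{const}
```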

112
avae/encoders.py Normal file

@@ -0,0 +1,112 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Encoder architectures to be used with VAE."""
import abc
from typing import Generic, TypeVar
import haiku as hk
import jax
import jax.numpy as jnp
from avae import types
_Params = TypeVar('_Params')
class EncoderBase(hk.Module, Generic[_Params]):
"""Abstract class for encoder architectures."""
def __init__(self, latent_dim: int):
"""Class initializer.
Args:
latent_dim: Latent dimensions of the model.
"""
super().__init__()
self._latent_dim = latent_dim
@abc.abstractmethod
def __call__(self, input_data: jnp.ndarray) -> _Params:
"""Return posterior distribution over latents.
Args:
input_data: Input batch of shape (batch_size, ...).
Returns:
Parameters of the posterior distribution over the latents.
"""
@abc.abstractmethod
def sample(self, posterior: _Params, key: jnp.ndarray) -> jnp.ndarray:
"""Sample from the given posterior distribution.
Args:
posterior: Parameters of posterior distribution over the latents.
key: Random number generator key.
Returns:
Sample from the posterior distribution over latents,
shape (batch_size, latent_dim).
"""
class ColorMnistMLPEncoder(EncoderBase[types.NormalParams]):
"""MLP encoder for ColorMnist."""
_hidden_units = (200, 200, 200, 200)
def __call__(
self, input_data: jnp.ndarray) -> types.NormalParams:
"""Return posterior distribution over latents.
Args:
input_data: Input batch of shape (batch_size, ...).
Returns:
Posterior distribution over the latents.
"""
out = hk.Flatten()(input_data)
for units in self._hidden_units:
out = hk.Linear(units)(out)
out = jax.nn.relu(out)
out = hk.Linear(2 * self._latent_dim)(out)
return _normal_params_from_logits(out)
def sample(
self,
posterior: types.NormalParams,
key: jnp.ndarray,
) -> jnp.ndarray:
"""Sample from the given normal posterior (mean, var) distribution.
Args:
posterior: Posterior over the latents.
key: Random number generator key.
Returns:
Sample from the posterior distribution over latents,
shape (batch_size, latent_dim).
"""
eps = jax.random.normal(
key, shape=(posterior.mean.shape[0], self._latent_dim))
# Scale by the standard deviation (sqrt of variance) to sample N(mean, var).
return posterior.mean + eps * jnp.sqrt(posterior.variance)
def _normal_params_from_logits(
logits: jnp.ndarray) -> types.NormalParams:
"""Construct mean and variance of normal distribution from given logits."""
mean, log_variance = jnp.split(logits, 2, axis=1)
variance = jnp.exp(log_variance)
return types.NormalParams(mean=mean, variance=variance)
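`sample` implements the standard reparameterization trick for the diagonal Gaussian posterior: with mean $\mu$ and variance $\sigma^2$ produced by the encoder,

```latex
z = \mu + \sqrt{\sigma^2} \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
```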

43
avae/kl.py Normal file

@@ -0,0 +1,43 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various KL implementations in JAX."""
import jax.numpy as jnp
def kl_p_with_uniform_normal(mean: jnp.ndarray,
variance: jnp.ndarray) -> jnp.ndarray:
r"""KL between p_dist with uniform normal prior.
Args:
mean: Mean of the gaussian distribution, shape (latent_dims,)
variance: Variance of the gaussian distribution, shape (latent_dims,)
Returns:
KL divergence KL(P||N(0, 1)) shape ()
"""
if len(variance.shape) == 2:
# If `variance` is a full covariance matrix
variance_trace = jnp.trace(variance)
_, ldet1 = jnp.linalg.slogdet(variance)
else:
variance_trace = jnp.sum(variance)
ldet1 = jnp.sum(jnp.log(variance))
mean_contribution = jnp.sum(jnp.square(mean))
res = -ldet1
res += variance_trace + mean_contribution - mean.shape[0]
return res * 0.5
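For reference, this implements the closed form for $P = \mathcal{N}(\mu, \Sigma)$ against a standard normal in $k$ dimensions (diagonal $\Sigma$ is the common case; the branch above also handles a full covariance matrix):

```latex
\mathrm{KL}\big(\mathcal{N}(\mu, \Sigma) \,\Vert\, \mathcal{N}(0, I)\big)
= \tfrac{1}{2}\left(\operatorname{tr}(\Sigma) + \mu^\top\mu - k - \log\det\Sigma\right)
```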

8
avae/requirements.txt Normal file

@@ -0,0 +1,8 @@
absl-py>0.9
dill==0.3.2
dm-haiku
jax
optax
tensorflow
tensorflow_datasets

31
avae/run.sh Executable file

@@ -0,0 +1,31 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
python -m venv avae_env
source avae_env/bin/activate
pip install --upgrade pip
pip install -r avae/requirements.txt
# Start training AVAE model.
python -m avae.train_main \
--dataset='color_mnist' \
--latent_dim=64 \
--checkpoint_dir='/tmp/avae_checkpoints' \
--checkpoint_filename='color_mnist_mlp_avae' \
--rho=0.975 \
--encoder=color_mnist_mlp_encoder \
--decoder=color_mnist_mlp_decoder \
--iterations=1

117
avae/train.py Normal file

@@ -0,0 +1,117 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VAE style training."""
from typing import Any, Callable, Iterator, Sequence, Mapping, Tuple, Optional
import haiku as hk
import jax
import jax.numpy as jnp
import optax
from avae import checkpointer
from avae import types
def train(
train_data_iterator: Iterator[types.LabelledData],
test_data_iterator: Iterator[types.LabelledData],
elbo_fun: hk.Transformed,
learning_rate: float,
checkpoint_dir: str,
checkpoint_filename: str,
checkpoint_every: int,
test_every: int,
iterations: int,
rng_seed: int,
test_functions: Optional[Sequence[Callable[[Mapping[str, jnp.ndarray]],
Tuple[str, float]]]] = None,
extra_checkpoint_info: Optional[Mapping[str, Any]] = None):
"""Train VAE with given data iterator and elbo definition.
Args:
train_data_iterator: Iterator of batched training data.
test_data_iterator: Iterator of batched testing data.
elbo_fun: Haiku-transformed function returning the ELBO.
learning_rate: Learning rate to be used with optimizer.
checkpoint_dir: Path of the checkpoint directory.
checkpoint_filename: Filename of the checkpoint.
checkpoint_every: Checkpoint every N iterations.
test_every: Test and log results every N iterations.
iterations: Number of training iterations to perform.
rng_seed: Seed for random number generator.
test_functions: Test function iterable, each function takes test data and
outputs extra info to print at test and log time.
extra_checkpoint_info: Extra info to put inside saved checkpoint.
"""
rng_seq = hk.PRNGSequence(jax.random.PRNGKey(rng_seed))
opt_init, opt_update = optax.chain(
# Set the parameters of Adam. Note the learning_rate is not here.
optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
# Put a minus sign to *minimise* the loss.
optax.scale(-learning_rate))
@jax.jit
def loss(params, key, data):
elbo_outputs = elbo_fun.apply(params, key, data)
return -jnp.mean(elbo_outputs.elbo)
@jax.jit
def loss_test(params, key, data):
elbo_output = elbo_fun.apply(params, key, data)
return (-jnp.mean(elbo_output.elbo), jnp.mean(elbo_output.data_fidelity),
jnp.mean(elbo_output.kl))
@jax.jit
def update_step(params, key, data, opt_state):
grads = jax.grad(loss, has_aux=False)(params, key, data)
updates, opt_state = opt_update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state
exp_checkpointer = checkpointer.Checkpointer(
checkpoint_dir, checkpoint_filename)
experiment_data = exp_checkpointer.load_checkpoint()
if experiment_data is not None:
start = experiment_data['step']
params = experiment_data['experiment_state']
opt_state = experiment_data['opt_state']
else:
start = 0
params = elbo_fun.init(
next(rng_seq), next(train_data_iterator).data)
opt_state = opt_init(params)
for step in range(start, iterations):
if step % test_every == 0:
test_loss, ll, kl = loss_test(params, next(rng_seq),
next(test_data_iterator).data)
output_message = (f'Step {step} elbo {-test_loss:0.2f} LL {ll:0.2f} '
f'KL {kl:0.2f}')
if test_functions:
for test_function in test_functions:
name, test_output = test_function(params)
output_message += f' {name}: {test_output:0.2f}'
print(output_message)
params, opt_state = update_step(params, next(rng_seq),
next(train_data_iterator).data, opt_state)
if step % checkpoint_every == 0:
exp_checkpointer.save_checkpoint(
params, opt_state, step, extra_checkpoint_info)

127
avae/train_main.py Normal file

@@ -0,0 +1,127 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Main file for training VAE/AVAE."""
import enum
from absl import app
from absl import flags
import haiku as hk
from avae import data_iterators
from avae import decoders
from avae import encoders
from avae import train
from avae import vae
class Model(enum.Enum):
vae = enum.auto()
avae = enum.auto()
class EncoderArch(enum.Enum):
color_mnist_mlp_encoder = 'ColorMnistMLPEncoder'
class DecoderArch(enum.Enum):
color_mnist_mlp_decoder = 'ColorMnistMLPDecoder'
_DATASET = flags.DEFINE_enum_class(
'dataset', data_iterators.Dataset.color_mnist, data_iterators.Dataset,
'Dataset to train on')
_LATENT_DIM = flags.DEFINE_integer('latent_dim', 32,
'Number of latent dimensions.')
_TRAIN_BATCH_SIZE = flags.DEFINE_integer('train_batch_size', 64,
'Train batch size.')
_TEST_BATCH_SIZE = flags.DEFINE_integer('test_batch_size', 64,
'Testing batch size.')
_TEST_EVERY = flags.DEFINE_integer('test_every', 1000,
'Test every N iterations.')
_ITERATIONS = flags.DEFINE_integer('iterations', 102000,
'Number of training iterations.')
_OBS_VAR = flags.DEFINE_float('obs_var', 0.5,
'Observation variance of the data. (Default 0.5)')
_MODEL = flags.DEFINE_enum_class('model', Model.avae, Model,
'Model used for training.')
_RHO = flags.DEFINE_float('rho', 0.8, 'Rho parameter used with AVAE or SE.')
_LEARNING_RATE = flags.DEFINE_float('learning_rate', 1e-4, 'Learning rate.')
_RNG_SEED = flags.DEFINE_integer('rng_seed', 0,
'Seed for random number generator.')
_CHECKPOINT_DIR = flags.DEFINE_string('checkpoint_dir', '/tmp/',
'Directory for checkpointing.')
_CHECKPOINT_FILENAME = flags.DEFINE_string(
'checkpoint_filename', 'color_mnist_avae_mlp', 'Checkpoint filename.')
_CHECKPOINT_EVERY = flags.DEFINE_integer(
'checkpoint_every', 1000, 'Checkpoint every N steps.')
_ENCODER = flags.DEFINE_enum_class(
'encoder', EncoderArch.color_mnist_mlp_encoder, EncoderArch,
'Encoder class name.')
_DECODER = flags.DEFINE_enum_class(
'decoder', DecoderArch.color_mnist_mlp_decoder, DecoderArch,
'Decoder class name.')
def main(_):
if _DATASET.value is data_iterators.Dataset.color_mnist:
train_data_iterator = iter(
data_iterators.ColorMnistDataIterator('train', _TRAIN_BATCH_SIZE.value))
test_data_iterator = iter(
data_iterators.ColorMnistDataIterator('test', _TEST_BATCH_SIZE.value))
def _elbo_fun(input_data):
if _ENCODER.value is EncoderArch.color_mnist_mlp_encoder:
encoder = encoders.ColorMnistMLPEncoder(_LATENT_DIM.value)
if _DECODER.value is DecoderArch.color_mnist_mlp_decoder:
decoder = decoders.ColorMnistMLPDecoder(_OBS_VAR.value)
vae_obj = vae.VAE(encoder, decoder, _RHO.value)
if _MODEL.value is Model.vae:
return vae_obj.vae_elbo(input_data, hk.next_rng_key())
else:
return vae_obj.avae_elbo(input_data, hk.next_rng_key())
elbo_fun = hk.transform(_elbo_fun)
extra_checkpoint_info = {
'dataset': _DATASET.value.name,
'encoder': _ENCODER.value.name,
'decoder': _DECODER.value.name,
'obs_var': _OBS_VAR.value,
'rho': _RHO.value,
'latent_dim': _LATENT_DIM.value,
}
train.train(
train_data_iterator=train_data_iterator,
test_data_iterator=test_data_iterator,
elbo_fun=elbo_fun,
learning_rate=_LEARNING_RATE.value,
checkpoint_dir=_CHECKPOINT_DIR.value,
checkpoint_filename=_CHECKPOINT_FILENAME.value,
checkpoint_every=_CHECKPOINT_EVERY.value,
test_every=_TEST_EVERY.value,
iterations=_ITERATIONS.value,
rng_seed=_RNG_SEED.value,
extra_checkpoint_info=extra_checkpoint_info)
if __name__ == '__main__':
app.run(main)

53
avae/types.py Normal file

@@ -0,0 +1,53 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Useful dataclasses types used across the code."""
from typing import Optional
import dataclasses
import jax.numpy as jnp
import numpy as np
@dataclasses.dataclass(frozen=True)
class ELBOOutputs:
elbo: jnp.ndarray
data_fidelity: jnp.ndarray
kl: jnp.ndarray
@dataclasses.dataclass(frozen=True)
class LabelledData:
"""A batch of labelled examples.
Attributes:
data: Array of shape (batch_size, ...).
label: Array of shape (batch_size, ...).
"""
data: np.ndarray
label: Optional[np.ndarray]
@dataclasses.dataclass(frozen=True)
class NormalParams:
"""Parameters of a normal distribution.
Attributes:
mean: Array of shape (batch_size, latent_dim).
variance: Array of shape (batch_size, latent_dim).
"""
mean: jnp.ndarray
variance: jnp.ndarray

133
avae/vae.py Normal file

@@ -0,0 +1,133 @@
# Copyright 2020 DeepMind Technologies Limited.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard VAE class."""
from typing import Optional
import jax
import jax.numpy as jnp
from avae import decoders
from avae import encoders
from avae import kl
from avae import types
class VAE:
"""VAE class.
This class defines the ELBOs used in training VAE models. It also adds a
function for forward passing data through the VAE.
"""
def __init__(self, encoder: encoders.EncoderBase,
decoder: decoders.DecoderBase, rho: Optional[float] = None):
"""Class initializer.
Args:
encoder: Encoder network architecture.
decoder: Decoder network architecture.
rho: Rho parameter used in AVAE training.
"""
self._encoder = encoder
self._decoder = decoder
self._rho = rho
def vae_elbo(
self, input_data: jnp.ndarray,
key: jnp.ndarray) -> types.ELBOOutputs:
"""ELBO for training VAE.
Args:
input_data: Input batch of shape (batch_size, ...).
key: Key for random number generator.
Returns:
Computed VAE ELBO as a types.ELBOOutputs.
"""
posterior = self._encoder(input_data)
samples = self._encoder.sample(posterior, key)
kls = jax.vmap(kl.kl_p_with_uniform_normal, [0, 0])(
posterior.mean, posterior.variance)
recons = self._decoder(samples)
data_fidelity = self._decoder.data_fidelity(input_data, recons)
elbo = data_fidelity - kls
return types.ELBOOutputs(elbo, data_fidelity, kls)
def avae_elbo(
self, input_data: jnp.ndarray,
key: jnp.ndarray) -> types.ELBOOutputs:
"""ELBO for training AVAE model.
Args:
input_data: Input batch of shape (batch_size, ...).
key: Key for random number generator.
Returns:
Computed AVAE ELBO as a types.ELBOOutputs. All arrays keep the batch
dimension intact.
"""
aux_images = jax.lax.stop_gradient(self(input_data, key))
posterior = self._encoder(input_data)
samples = self._encoder.sample(posterior, key)
kls = jax.vmap(kl.kl_p_with_uniform_normal, [0, 0])(
posterior.mean, posterior.variance)
recons = self._decoder(samples)
data_fidelity = self._decoder.data_fidelity(input_data, recons)
elbo = data_fidelity - kls
aux_posterior = self._encoder(aux_images)
latent_mean = posterior.mean
latent_var = posterior.variance
aux_latent_mean = aux_posterior.mean
aux_latent_var = aux_posterior.variance
latent_dim = latent_mean.shape[1]
def _reduce(x):
return jnp.mean(jnp.sum(x, axis=1))
# Computation of <log p(Z_aux | Z)>.
expected_log_conditional = (
aux_latent_var + jnp.square(self._rho) * latent_var +
jnp.square(aux_latent_mean - self._rho * latent_mean))
expected_log_conditional = _reduce(expected_log_conditional)
expected_log_conditional /= 2.0 * (1.0 - jnp.square(self._rho))
expected_log_conditional = (latent_dim *
jnp.log(1.0 / (2 * jnp.pi)) -
expected_log_conditional)
elbo += expected_log_conditional
# Entropy of Z_aux
elbo += _reduce(0.5 * jnp.log(2 * jnp.pi * jnp.e * aux_latent_var))
return types.ELBOOutputs(elbo, data_fidelity, kls)
def __call__(
self, input_data: jnp.ndarray,
key: jnp.ndarray) -> jnp.ndarray:
"""Reconstruction of the input data.
Args:
input_data: Input batch of shape (batch_size, ...).
key: Key for random number generator.
Returns:
Reconstruction of the input data as jnp.ndarray of shape
[batch_dim, observation_dims].
"""
posterior = self._encoder(input_data)
samples = self._encoder.sample(posterior, key)
recons = self._decoder(samples)
return recons
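The two extra AVAE terms above follow from the conditional prior $p(z_{\text{aux}} \mid z) = \mathcal{N}(\rho z, (1 - \rho^2) I)$ with diagonal Gaussian posteriors. Written per latent coordinate, with additive constants (which do not depend on the network parameters, and so do not affect gradients) collected into const:

```latex
\mathbb{E}\big[\log p(z_{\text{aux}} \mid z)\big]
= -\frac{\sigma_{\text{aux}}^2 + \rho^2 \sigma^2 + (\mu_{\text{aux}} - \rho\mu)^2}{2(1 - \rho^2)} + \text{const},
\qquad
\mathbb{H}\big[q(z_{\text{aux}})\big] = \tfrac{1}{2}\log\big(2\pi e\,\sigma_{\text{aux}}^2\big)
```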

83
fusion_tcv/README.md Normal file

@@ -0,0 +1,83 @@
# TCV Fusion Control Objectives
This code release contains the rewards, control targets, noise model and
parameter variation for the paper *Magnetic control of tokamak plasmas through
deep reinforcement learning*, published in Nature in ... 2021.
## Disclaimer
This release is useful for understanding the details of specific elements of
the learning architecture; however, it does not contain the simulator
([FGE](https://infoscience.epfl.ch/record/283775), part of
[LIUQE](https://www.epfl.ch/research/domains/swiss-plasma-center/liuqe-suite/)),
the trained/exported control policies, or the agent training infrastructure.
This release is useful for replicating our results, which can be done by
assembling all the components, many of which are open source elsewhere. Please
see the "Code Availability" statement in the paper for more information.
The learning algorithm we used is
[MPO](https://arxiv.org/abs/1812.02256), which has an open source
[reference implementation](https://github.com/deepmind/acme) in
[Acme](https://arxiv.org/abs/2006.00979). Additionally, the open source software
libraries [launchpad](https://arxiv.org/abs/2106.04516)
([code](https://github.com/deepmind/launchpad)),
[dm_env](https://github.com/deepmind/dm_env),
[sonnet](https://github.com/deepmind/sonnet),
[tensorflow](https://arxiv.org/abs/1603.04467)
([code](https://www.tensorflow.org/)) and
[reverb](https://arxiv.org/pdf/2102.04736.pdf)
([code](https://github.com/deepmind/reverb)) were used.
FGE and LIUQE are available on request from the
[Swiss Plasma Center](https://www.epfl.ch/research/domains/swiss-plasma-center/)
at [EPFL](https://www.epfl.ch/en/)
(email [Federico Felici](mailto:federico.felici@epfl.ch)), subject to agreement.
## Objectives used in published TCV experiments
Take a look at `rewards_used.py` and `references.py` for the rewards and control
targets used in the paper.
To print the actual control targets, run:
```sh
$ python3 -m fusion_tcv.references_main --refs=snowflake
```
Make sure to install the dependencies:
`pip install -r fusion_tcv/requirements.txt`.
## Overview of files
* `agent.py`: An interface for what a real agent might look like. This mainly
exists so the run loop builds.
* `combiners.py`: Defines combiners that merge multiple values into one.
Useful for creating a scalar reward from a set of reward components.
* `environment.py`: Augments the FGE simulator to make it an RL environment
(parameter variation, reward computation, etc.).
* `experiments.py`: The environment definitions published in the paper.
* `fge_octave.py`: An interface for the simulator. Given that FGE isn't open
source, this is just a sketch of how you might use it.
* `fge_state.py`: A Python interface to the simulator state. Given that FGE
isn't open source, this is a sketch and returns fake data.
* `named_array.py`: Makes it easier to interact with numpy arrays by name.
Useful for referring to parts of the control targets/references.
* `noise.py`: Adds action and observation noise.
* `param_variation.py`: Defines the physics parameters used by the simulation.
* `ref_gen.py`: The tools to generate control targets per time step.
* `references.py`: The control targets used in the experiments.
* `references_main.py`: Runnable script to output the references to the command
line.
* `rewards.py`: Computes all of the reward components and combines them
to create a single scalar reward for the environment.
* `rewards_used.py`: The actual reward definitions used in the experiments.
* `run_loop.py`: An example of interacting with the environment to generate a
trajectory to send to the replay buffer. This code is notional, as a complete
version would require an Agent and a Simulator implementation (see the sketch
after this list).
* `targets.py`: The reward components that pull data from the control targets
and the simulator state for generating rewards. This depends on FGE, so
it cannot be run.
* `tcv_common.py`: Defines the physical attributes of the TCV fusion reactor,
and how to interact with it.
* `terminations.py`: Defines when the simulation should stop.
* `trajectory.py`: Stores the history of the episode.
* `transforms.py`: Turns error values into normalized values.
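A minimal sketch of wiring these pieces together, mirroring the tests included in this release (the simulator stub returns fake data, so this exercises the plumbing rather than real physics):

```python
from fusion_tcv import agent
from fusion_tcv import experiments
from fusion_tcv import run_loop

# ZeroAgent emits all-zero voltage actions.
env = experiments.snowflake()
trajectory = run_loop.run_loop(env, agent.ZeroAgent(), max_steps=10)
print(len(trajectory.reward))
```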

41
fusion_tcv/agent.py Normal file

@@ -0,0 +1,41 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An agent interface for interacting with the environment."""
import abc
import dm_env
import numpy as np
from fusion_tcv import tcv_common
class AbstractAgent(abc.ABC):
"""Agent base class."""
def reset(self):
"""Reset to the initial state."""
@abc.abstractmethod
def step(self, timestep: dm_env.TimeStep) -> np.ndarray:
"""Return the action given the current observations."""
class ZeroAgent(AbstractAgent):
"""An agent that always returns "zero" actions."""
def step(self, timestep: dm_env.TimeStep) -> np.ndarray:
del timestep
return np.zeros(tcv_common.action_spec().shape,
tcv_common.action_spec().dtype)

220
fusion_tcv/combiners.py Normal file

@@ -0,0 +1,220 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward combiners."""
import abc
import math
from typing import List, Optional, Tuple
import dataclasses
import numpy as np
from scipy import special
from fusion_tcv import targets
class AbstractCombiner(targets.AbstractTarget):
"""Combines a set of rewards, possibly weighted."""
@abc.abstractmethod
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
"""Combines a set of rewards, possibly weighted."""
@property
def outputs(self) -> int:
"""All combiners return exactly one value, even if it's NaN."""
return 1
@staticmethod
def _clean_values_weights(
values: List[float],
weights: Optional[List[float]] = None) -> Tuple[List[float], List[float]]:
"""Validate the values and weights, and if no weights, return equal."""
if weights is None:
weights = [1] * len(values)
else:
if len(values) != len(weights):
raise ValueError("Number of weights don't match values. "
f"values: {len(values)}, weights: {len(weights)}")
for w in weights:
if w < 0:
raise ValueError(f"Weights must be >=0: {w}")
new_values_weights = [(v, w) for v, w in zip(values, weights)
if not np.isnan(v) and w > 0]
return tuple(zip(*new_values_weights)) if new_values_weights else ([], [])
class Mean(AbstractCombiner):
"""Take the weighted mean of the values.
Ignores NaNs and values with weight 0.
"""
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, weights = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
return [sum(r * w for r, w in zip(values, weights)) / sum(weights)]
def _multiply(values, weights, mean):
"""Multiplies the values taking care to validate the weights.
Defines 0^0 = 1 so a reward with no weight is "off" even if the value is 0.
Args:
values: The reward values.
weights: The reward weights.
mean: If true, divides by the sum of the weights (computes the geometric
mean).
Returns:
Product of v^w across the components.
"""
# If weight and value are both zero, set the value to 1 so that 0^0 = 1.
values = [1 if (v == 0 and w == 0) else v for (v, w) in zip(values, weights)]
if any(v == 0 for v in values):
return [0]
den = sum(weights) if mean else 1
return [math.exp(sum(np.log(values) * weights) / den)]
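`_multiply` evaluates the weighted product in log space; with `mean=True` this is the weighted geometric mean:

```latex
\mathrm{GM}(v; w) = \exp\!\left(\frac{\sum_i w_i \log v_i}{\sum_i w_i}\right)
= \prod_i v_i^{\,w_i / \sum_j w_j}
```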
class Multiply(AbstractCombiner):
"""Combine by multiplying the (weighted) values together.
This is the same as the geometric mean, but without the n-th root taken at the
end. This means doing poorly on several rewards compounds, rather than
averages. As such it likely only makes sense after the non-linearities, i.e.
where the values are in the 0-1 range; otherwise it'll cause them to increase.
This is even harsher than Min or SmoothMax(-inf).
Ignores NaNs and values with weight 0.
"""
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, weights = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
return _multiply(values, weights, mean=False)
class GeometricMean(AbstractCombiner):
"""Take the weighted geometric mean of the values.
Pushes values towards 0, so likely only makes sense after the non-linear
transforms.
Ignores NaNs and values with weight 0.
"""
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, weights = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
return _multiply(values, weights, mean=True)
class Min(AbstractCombiner):
"""Take the min of the values. Ignores NaNs and values with weight 0."""
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, _ = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
return [min(values)]
class Max(AbstractCombiner):
"""Take the max of the values. Ignores NaNs and values with weight 0."""
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, _ = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
return [max(values)]
@dataclasses.dataclass(frozen=True)
class LNorm(AbstractCombiner):
"""Take the l-norm of the values.
Reasonable norm values (assuming normalized):
- 1: avg of the values
- 2: euclidean distance metric
- inf: max value
Values in between go between the average and max. As the l-norm goes up, the
result gets closer to the max.
Normalized means dividing by the max possible distance, such that the units
still make sense.
This likely only makes sense before the non-linear transforms. SmoothMax is
similar but more flexible and understandable.
Ignores NaNs and values with weight 0.
"""
norm: float
normalized: bool = True
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, _ = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
lnorm = np.linalg.norm(values, ord=self.norm)
if self.normalized:
lnorm /= np.linalg.norm(np.ones(len(values)), ord=self.norm)
return [float(lnorm)]
@dataclasses.dataclass(frozen=True)
class SmoothMax(AbstractCombiner):
"""Combines component rewards using a smooth maximum.
https://en.wikipedia.org/wiki/Smooth_maximum
alpha is the exponent for the smooth max.
- alpha -> inf: returns the maximum
- alpha == 0: returns the weighted average
- alpha -> -inf: returns the minimum
alpha in between returns values in between.
Since this varies between min, mean and max, it keeps the existing scale.
Alpha >= 0 makes sense before converting to the 0-1 range; alpha <= 0 makes
sense after.
Ignores NaNs and values with weight 0.
"""
alpha: float
def __call__(self, values: List[float],
weights: Optional[List[float]] = None) -> List[float]:
values, weights = self._clean_values_weights(values, weights)
if not values:
return [float("nan")]
if math.isinf(self.alpha):
return [max(values) if self.alpha > 0 else min(values)]
# Compute weights in a numerically-friendly way.
log_soft_weights = [np.log(w) + c * self.alpha
for w, c in zip(weights, values)]
log_soft_weights -= special.logsumexp(log_soft_weights)
soft_weights = np.exp(log_soft_weights)
return Mean()(values, soft_weights) # weighted mean
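The weighted smooth maximum computed above has the closed form below; the logsumexp in the code is just a numerically stable way to evaluate the softmax weights:

```latex
S_\alpha(v; w) = \frac{\sum_i w_i\, v_i\, e^{\alpha v_i}}{\sum_i w_i\, e^{\alpha v_i}}
```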

157
fusion_tcv/combiners_test.py Normal file

@@ -0,0 +1,157 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for combiners."""
import math
from absl.testing import absltest
from fusion_tcv import combiners
NAN = float("nan")
class CombinersTest(absltest.TestCase):
def assertNan(self, value):
self.assertLen(value, 1)
self.assertTrue(math.isnan(value[0]))
def test_errors(self):
c = combiners.Mean()
with self.assertRaises(ValueError):
c([0, 1], [1])
with self.assertRaises(ValueError):
c([0, 1], [1, 2, 3])
with self.assertRaises(ValueError):
c([0, 1], [-1, 2])
def test_mean(self):
c = combiners.Mean()
self.assertEqual(c([0, 2, 4]), [2])
self.assertEqual(c([0, 0.5, 1]), [0.5])
self.assertEqual(c([0, 0.5, 1], [0, 0, 1]), [1])
self.assertEqual(c([0, 1], [1, 3]), [0.75])
self.assertEqual(c([0, NAN], [1, 3]), [0])
self.assertNan(c([NAN, NAN], [1, 3]))
def test_geometric_mean(self):
c = combiners.GeometricMean()
self.assertEqual(c([0.5, 0]), [0])
self.assertEqual(c([0.3]), [0.3])
self.assertEqual(c([4, 4]), [4])
self.assertEqual(c([0.5, 0.5]), [0.5])
self.assertEqual(c([0.5, 0.5], [1, 3]), [0.5])
self.assertEqual(c([0.5, 1], [1, 2]), [0.5**(1/3)])
self.assertEqual(c([0.5, 1], [2, 1]), [0.5**(2/3)])
self.assertEqual(c([0.5, 0], [2, 0]), [0.5])
self.assertEqual(c([0.5, 0, 0], [2, 1, 0]), [0])
self.assertEqual(c([0.5, NAN, 0], [2, 1, 0]), [0.5])
self.assertNan(c([NAN, NAN], [1, 3]))
def test_multiply(self):
c = combiners.Multiply()
self.assertEqual(c([0.5, 0]), [0])
self.assertEqual(c([0.3]), [0.3])
self.assertEqual(c([0.5, 0.5]), [0.25])
self.assertEqual(c([0.5, 0.5], [1, 3]), [0.0625])
self.assertEqual(c([0.5, 1], [1, 2]), [0.5])
self.assertEqual(c([0.5, 1], [2, 1]), [0.25])
self.assertEqual(c([0.5, 0], [2, 0]), [0.25])
self.assertEqual(c([0.5, 0, 0], [2, 1, 0]), [0])
self.assertEqual(c([0.5, NAN], [1, 1]), [0.5])
self.assertNan(c([NAN, NAN], [1, 3]))
def test_min(self):
c = combiners.Min()
self.assertEqual(c([0, 1]), [0])
self.assertEqual(c([0.5, 1]), [0.5])
self.assertEqual(c([1, 0.75]), [0.75])
self.assertEqual(c([1, 3]), [1])
self.assertEqual(c([1, 1, 3], [0, 1, 1]), [1])
self.assertEqual(c([NAN, 3]), [3])
self.assertNan(c([NAN, NAN], [1, 3]))
def test_max(self):
c = combiners.Max()
self.assertEqual(c([0, 1]), [1])
self.assertEqual(c([0.5, 1]), [1])
self.assertEqual(c([1, 0.75]), [1])
self.assertEqual(c([1, 3]), [3])
self.assertEqual(c([1, 1, 3], [0, 1, 1]), [3])
self.assertEqual(c([NAN, 3]), [3])
self.assertNan(c([NAN, NAN], [1, 3]))
def test_lnorm(self):
c = combiners.LNorm(1)
self.assertEqual(c([0, 2, 4]), [2])
self.assertEqual(c([0, 0.5, 1]), [0.5])
self.assertEqual(c([3, 4]), [7 / 2])
self.assertEqual(c([0, 2, 4], [1, 1, 0]), [1])
self.assertEqual(c([0, 2, NAN]), [1])
self.assertNan(c([NAN, NAN], [1, 3]))
c = combiners.LNorm(1, normalized=False)
self.assertEqual(c([0, 2, 4]), [6])
self.assertEqual(c([0, 0.5, 1]), [1.5])
self.assertEqual(c([3, 4]), [7])
c = combiners.LNorm(2)
self.assertEqual(c([3, 4]), [5 / 2**0.5])
c = combiners.LNorm(2, normalized=False)
self.assertEqual(c([3, 4]), [5])
c = combiners.LNorm(math.inf)
self.assertAlmostEqual(c([3, 4])[0], 4)
c = combiners.LNorm(math.inf, normalized=False)
self.assertAlmostEqual(c([3, 4])[0], 4)
def test_smoothmax(self):
# Max
c = combiners.SmoothMax(math.inf)
self.assertEqual(c([0, 1]), [1])
self.assertEqual(c([0.5, 1]), [1])
self.assertEqual(c([1, 0.75]), [1])
self.assertEqual(c([1, 3]), [3])
# Smooth Max
c = combiners.SmoothMax(1)
self.assertAlmostEqual(c([0, 1])[0], 0.7310585786300049)
# Mean
c = combiners.SmoothMax(0)
self.assertEqual(c([0, 2, 4]), [2])
self.assertEqual(c([0, 0.5, 1]), [0.5])
self.assertEqual(c([0, 0.5, 1], [0, 0, 1]), [1])
self.assertEqual(c([0, 2, NAN]), [1])
self.assertEqual(c([0, 2, NAN], [0, 1, 1]), [2])
self.assertAlmostEqual(c([0, 1], [1, 3])[0], 0.75)
self.assertNan(c([NAN, NAN], [1, 3]))
# Smooth Min
c = combiners.SmoothMax(-1)
self.assertEqual(c([0, 1])[0], 0.2689414213699951)
# Min
c = combiners.SmoothMax(-math.inf)
self.assertEqual(c([0, 1]), [0])
self.assertEqual(c([0.5, 1]), [0.5])
self.assertEqual(c([1, 0.75]), [0.75])
self.assertEqual(c([1, 3]), [1])
if __name__ == "__main__":
absltest.main()

160
fusion_tcv/environment.py Normal file

@@ -0,0 +1,160 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Environment API for FGE simulator."""
from typing import Dict, List, Optional
import dm_env
from dm_env import auto_reset_environment
from dm_env import specs
import numpy as np
from fusion_tcv import fge_octave
from fusion_tcv import fge_state
from fusion_tcv import named_array
from fusion_tcv import noise
from fusion_tcv import param_variation
from fusion_tcv import ref_gen
from fusion_tcv import rewards
from fusion_tcv import tcv_common
from fusion_tcv import terminations
# Re-export as fge_octave should be an implementation detail.
ShotCondition = fge_octave.ShotCondition
class Environment(auto_reset_environment.AutoResetEnvironment):
"""An environment using the FGE Solver.
The simulator will return a flux map, which is the environment's hidden state,
and some flux measurements, which will be used as observations. The actions
represent current levels that are passed to the simulator for the next
flux calculation.
"""
def __init__(
self,
shot_condition: ShotCondition,
reward: rewards.AbstractReward,
reference_generator: ref_gen.AbstractReferenceGenerator,
max_episode_length: int = 10000,
termination: Optional[terminations.Abstract] = None,
obs_act_noise: Optional[noise.Noise] = None,
power_supply_delays: Optional[Dict[str, List[float]]] = None,
param_generator: Optional[param_variation.ParamGenerator] = None):
"""Initializes an Environment instance.
Args:
shot_condition: A ShotCondition, specifying shot number and time. This
specifies the machine geometry (e.g. with or without the baffles), and the
initial measurements, voltages, current and plasma state.
reward: Function to generate a reward term.
reference_generator: Generator for the signal to send to references.
max_episode_length: Maximum number of steps before episode is truncated
and restarted.
termination: Decide if the state should be considered a termination.
obs_act_noise: Observation and action noise settings. If None, the default
noise level is used.
power_supply_delays: A dict with power supply delays (in seconds), keys
are coil type labels ('E', 'F', 'G', 'OH'). `None` means default delays.
param_generator: Generator for LIUQE parameter settings. If None, the
default settings are used.
"""
super().__init__()
if power_supply_delays is None:
power_supply_delays = tcv_common.TCV_ACTION_DELAYS
self._simulator = fge_octave.FGESimulatorOctave(
shot_condition=shot_condition,
power_supply_delays=power_supply_delays)
self._reward = reward
self._reference_generator = reference_generator
self._max_episode_length = max_episode_length
self._termination = (termination if termination is not None else
terminations.CURRENT_OH_IP)
self._noise = (obs_act_noise if obs_act_noise is not None else
noise.Noise.use_default_noise())
self._param_generator = (param_generator if param_generator is not None else
param_variation.ParamGenerator())
self._params = None
self._step_counter = 0
self._last_observation = None
def observation_spec(self):
"""Defines the observations provided by the environment."""
return tcv_common.observation_spec()
def action_spec(self) -> specs.BoundedArray:
"""Defines the actions that should be provided to `step`."""
return tcv_common.action_spec()
def _reset(self) -> dm_env.TimeStep:
"""Starts a new episode."""
self._step_counter = 0
self._params = self._param_generator.generate()
state = self._simulator.reset(self._params)
references = self._reference_generator.reset()
zero_act = np.zeros(self.action_spec().shape,
dtype=self.action_spec().dtype)
self._last_observation = self._extract_observation(
state, references, zero_act)
return dm_env.restart(self._last_observation)
def _simulator_voltages_from_voltages(self, voltages):
voltage_simulator = np.copy(voltages)
if self._params.psu_voltage_offset is not None:
for coil, offset in self._params.psu_voltage_offset.items():
voltage_simulator[tcv_common.TCV_ACTION_INDICES[coil]] += offset
voltage_simulator = np.clip(
voltage_simulator,
self.action_spec().minimum,
self.action_spec().maximum)
g_coil = tcv_common.TCV_ACTION_RANGES.index("G")
if abs(voltage_simulator[g_coil]) < tcv_common.ENV_G_COIL_DEADBAND:
voltage_simulator[g_coil] = 0
return voltage_simulator
def _step(self, action: np.ndarray) -> dm_env.TimeStep:
"""Does one step within TCV."""
voltages = self._noise.add_action_noise(action)
voltage_simulator = self._simulator_voltages_from_voltages(voltages)
try:
state = self._simulator.step(voltage_simulator)
except (fge_state.InvalidSolutionError,
fge_state.StopSignalException):
return dm_env.termination(
self._reward.terminal_reward(), self._last_observation)
references = self._reference_generator.step()
self._last_observation = self._extract_observation(
state, references, action)
term = self._termination.terminate(state)
if term:
return dm_env.termination(
self._reward.terminal_reward(), self._last_observation)
reward, _ = self._reward.reward(voltages, state, references)
self._step_counter += 1
if self._step_counter >= self._max_episode_length:
return dm_env.truncation(reward, self._last_observation)
return dm_env.transition(reward, self._last_observation)
def _extract_observation(
self, state: fge_state.FGEState,
references: named_array.NamedArray,
action: np.ndarray) -> Dict[str, np.ndarray]:
return {
"references": references.array,
"measurements": self._noise.add_measurement_noise(
state.get_observation_vector()),
"last_action": action,
}

72
fusion_tcv/experiments.py Normal file

@@ -0,0 +1,72 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The environment definitions used for our experiments."""
from fusion_tcv import environment
from fusion_tcv import references
from fusion_tcv import rewards_used
# Used in TCV#70915
def fundamental_capability() -> environment.Environment:
return environment.Environment(
shot_condition=environment.ShotCondition(70166, 0.0872),
reward=rewards_used.FUNDAMENTAL_CAPABILITY,
reference_generator=references.fundamental_capability(),
max_episode_length=10000)
# Used in TCV#70920
def elongation() -> environment.Environment:
return environment.Environment(
shot_condition=environment.ShotCondition(70166, 0.45),
reward=rewards_used.ELONGATION,
reference_generator=references.elongation(),
max_episode_length=5000)
# Used in TCV#70600
def iter() -> environment.Environment: # pylint: disable=redefined-builtin
return environment.Environment(
shot_condition=environment.ShotCondition(70392, 0.0872),
reward=rewards_used.ITER,
reference_generator=references.iter(),
max_episode_length=1000)
# Used in TCV#70457
def negative_triangularity() -> environment.Environment:
return environment.Environment(
shot_condition=environment.ShotCondition(70166, 0.45),
reward=rewards_used.NEGATIVE_TRIANGULARITY,
reference_generator=references.negative_triangularity(),
max_episode_length=5000)
# Used in TCV#70755
def snowflake() -> environment.Environment:
return environment.Environment(
shot_condition=environment.ShotCondition(70166, 0.0872),
reward=rewards_used.SNOWFLAKE,
reference_generator=references.snowflake(),
max_episode_length=10000)
# Used in TCV#69545
def droplet() -> environment.Environment:
return environment.Environment(
shot_condition=environment.ShotCondition(69198, 0.418),
reward=rewards_used.DROPLETS,
reference_generator=references.droplet(),
max_episode_length=2000)

79
fusion_tcv/experiments_test.py Normal file

@@ -0,0 +1,79 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for experiments."""
from absl.testing import absltest
from absl.testing import parameterized
from dm_env import test_utils
from fusion_tcv import agent
from fusion_tcv import experiments
from fusion_tcv import run_loop
class FundamentalCapabilityTest(test_utils.EnvironmentTestMixin,
absltest.TestCase):
def make_object_under_test(self):
return experiments.fundamental_capability()
class ElongationTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
def make_object_under_test(self):
return experiments.elongation()
class IterTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
def make_object_under_test(self):
return experiments.iter()
class NegativeTriangularityTest(test_utils.EnvironmentTestMixin,
absltest.TestCase):
def make_object_under_test(self):
return experiments.negative_triangularity()
class SnowflakeTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
def make_object_under_test(self):
return experiments.snowflake()
class DropletTest(test_utils.EnvironmentTestMixin, absltest.TestCase):
def make_object_under_test(self):
return experiments.droplet()
class ExperimentsTest(parameterized.TestCase):
@parameterized.named_parameters(
("fundamental_capability", experiments.fundamental_capability),
("elongation", experiments.elongation),
("iter", experiments.iter),
("negative_triangularity", experiments.negative_triangularity),
("snowflake", experiments.snowflake),
("droplet", experiments.droplet),
)
def test_env(self, env_fn):
traj = run_loop.run_loop(env_fn(), agent.ZeroAgent(), max_steps=10)
self.assertGreaterEqual(len(traj.reward), 1)
if __name__ == "__main__":
absltest.main()

77
fusion_tcv/fge_octave.py Normal file
View File

@@ -0,0 +1,77 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Actually interact with FGE via octave."""
from typing import Dict, List
import dataclasses
import numpy as np
from fusion_tcv import fge_state
from fusion_tcv import param_variation
SUBSTEPS = 5
@dataclasses.dataclass
class ShotCondition:
"""Represents a shot and time from a real shot."""
shot: int
time: float
class FGESimulatorOctave:
"""Would interact with the FGE solver via Octave.
Given that FGE isn't open source, this is just a sketch.
"""
def __init__(
self,
shot_condition: ShotCondition,
power_supply_delays: Dict[str, List[float]]):
"""Initialize the simulator.
Args:
shot_condition: A ShotCondition, specifying shot number and time. This
specifies the machine geometry (e.g. with or without the baffles), and the
initial measurements, voltages, current and plasma shape.
power_supply_delays: A dict with power supply delays (in seconds), keys
are coil type labels ('E', 'F', 'G', 'OH'). `None` means default delays.
"""
del power_supply_delays
# Initialize the simulator:
# - Use oct2py to load FGE through Octave.
# - Load the data for the shot_condition.
# - Set up the reactor geometry from the shot_condition.
# - Set the timestep to `tcv_common.DT / SUBSTEPS`.
# - Set up the solver for singlets or droplets based on the shot_condition.
self._num_plasmas = 2 if shot_condition.shot == 69198 else 1
# - Set up the power supply, including the limits, initial data, and delays.
def reset(self, variation: param_variation.Settings) -> fge_state.FGEState:
"""Restarts the simulator with parameters."""
del variation
# Update the simulator with the current physics parameters.
# Reset to the initial state from the shot_condition.
return fge_state.FGEState(self._num_plasmas) # Filled with the real state.
def step(self, voltages: np.ndarray) -> fge_state.FGEState:
"""Run the simulator with `voltages`, returns the state."""
del voltages
# for _ in range(SUBSTEPS):
# Step the simulator with `voltages`.
# raise fge_state.InvalidSolutionError if the solver doesn't converge.
# raise fge_state.StopSignalException if an internal termination is triggered.
return fge_state.FGEState(self._num_plasmas) # Filled with the real state.
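
Since FGE isn't open source, the class above is only an interface sketch. Under that assumption, a driver loop would look roughly like this; the power supply delay values are illustrative placeholders, not real TCV numbers:

```python
# Hypothetical driver loop around the sketched simulator.
import numpy as np
from fusion_tcv import fge_octave
from fusion_tcv import param_variation
from fusion_tcv import tcv_common

sim = fge_octave.FGESimulatorOctave(
    shot_condition=fge_octave.ShotCondition(shot=70166, time=0.45),
    power_supply_delays={"E": [0.0], "F": [0.0], "G": [0.0], "OH": [0.0]})
state = sim.reset(param_variation.Settings())  # Default physics parameters.
for _ in range(5):
  state = sim.step(np.zeros(tcv_common.NUM_ACTIONS))  # Zero coil voltages.
print(state.num_plasmas)
```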

121
fusion_tcv/fge_state.py Normal file
View File

@@ -0,0 +1,121 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A nice python representation of the underlying FGE state."""
from typing import List, Tuple
import numpy as np
from fusion_tcv import shape
from fusion_tcv import shapes_known
from fusion_tcv import tcv_common
class StopSignalException(Exception): # pylint: disable=g-bad-exception-name
"""This is raised if the FGE environment raises the Stop/Alarm signal."""
pass
class InvalidSolutionError(RuntimeError):
"""This is raised if returned solution is invalid."""
pass
class UnhandledOctaveError(Exception):
"""This is raised if some Octave code raises an unhandled error."""
pass
class FGEState:
"""A nice python representation of the underlying FGE State.
Given that FGE isn't open source, all of these numbers are made up; this is
only a sketch of what the state could look like.
"""
def __init__(self, num_plasmas):
self._num_plasmas = num_plasmas
@property
def num_plasmas(self) -> int:
return self._num_plasmas # Return 1 for singlet, 2 for droplets.
@property
def rzip_d(self) -> Tuple[List[float], List[float], List[float]]:
"""Returns the R, Z, and Ip for each plasma domain."""
if self.num_plasmas == 1:
return [0.9], [0], [-120000]
else:
return [0.9, 0.88], [0.4, -0.4], [-60000, -65000]
def get_coil_currents_by_type(self, coil_type) -> np.ndarray:
currents = tcv_common.TCV_ACTION_RANGES.new_random_named_array()
return currents[coil_type] * tcv_common.ENV_COIL_MAX_CURRENTS[coil_type] / 5
def get_lcfs_points(self, domain: int) -> shape.ShapePoints:
del domain # Should be plasma domain specific
return shapes_known.SHAPE_70166_0872.canonical().points
def get_observation_vector(self) -> np.ndarray:
return tcv_common.TCV_MEASUREMENT_RANGES.new_random_named_array().array
@property
def elongation(self) -> List[float]:
return [1.4] * self.num_plasmas
@property
def triangularity(self) -> List[float]:
return [0.25] * self.num_plasmas
@property
def radius(self) -> List[float]:
return [0.23] * self.num_plasmas
@property
def limit_point_d(self) -> List[shape.Point]:
return [shape.Point(tcv_common.INNER_LIMITER_R, 0.2)] * self.num_plasmas
@property
def is_diverted_d(self) -> List[bool]:
return [False] * self.num_plasmas
@property
def x_points(self) -> shape.ShapePoints:
return []
@property
def flux(self) -> np.ndarray:
"""Return the flux at the grid coordinates."""
return np.random.random((len(self.z_coordinates), len(self.r_coordinates)))
@property
def magnetic_axis_flux_strength(self) -> float:
"""The magnetic flux at the center of the plasma."""
return 2
@property
def lcfs_flux_strength(self) -> float:
"""The flux at the LCFS."""
return 1
@property
def r_coordinates(self) -> np.ndarray:
"""The radial coordinates of the simulation."""
return np.arange(tcv_common.INNER_LIMITER_R, tcv_common.OUTER_LIMITER_R,
tcv_common.LIMITER_WIDTH / 10) # Made up grid resolution.
@property
def z_coordinates(self) -> np.ndarray:
"""The vertical coordinates of the simulation."""
return np.arange(-0.75, 0.75, 1.5 / 30) # Made up numbers.
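
A quick sketch of reading this stub state; the values are the placeholders defined above, not physics:

```python
# Inspect the made-up state for a single plasma domain.
from fusion_tcv import fge_state

state = fge_state.FGEState(num_plasmas=1)
r, z, ip = state.rzip_d  # ([0.9], [0], [-120000]) in this sketch.
psi = state.flux         # Placeholder flux, indexed as (z, r).
print(len(r), psi.shape, state.elongation)
```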

127
fusion_tcv/named_array.py Normal file
View File

@@ -0,0 +1,127 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Give names to parts of a numpy array."""
from typing import Iterable, List, Mapping, MutableMapping, Tuple, Union
import numpy as np
def lengths_to_ranges(
lengths: Mapping[str, int]) -> MutableMapping[str, List[int]]:
"""Eg: {a: 2, b: 3} -> {a: [0, 1], b: [2, 3, 4]} ."""
ranges = {}
start = 0
for key, length in lengths.items():
ranges[key] = list(range(start, start + length))
start += length
return ranges
class NamedRanges:
"""Given a map of {key: count}, give various views into it."""
def __init__(self, counts: Mapping[str, int]):
self._ranges = lengths_to_ranges(counts)
self._size = sum(counts.values())
def __getitem__(self, name) -> List[int]:
return self._ranges[name]
def __contains__(self, name) -> bool:
return name in self._ranges
def set_range(self, name: str, value: List[int]):
"""Overwrite or create a custom range, which may intersect with others."""
self._ranges[name] = value
def range(self, name: str) -> List[int]:
return self[name]
def index(self, name: str) -> int:
rng = self[name]
if len(rng) != 1:
raise ValueError(f"{name} has multiple values")
return rng[0]
def count(self, name: str) -> int:
return len(self[name])
def names(self) -> Iterable[str]:
return self._ranges.keys()
def ranges(self) -> Iterable[Tuple[str, List[int]]]:
return self._ranges.items()
def counts(self) -> Mapping[str, int]:
return {k: len(v) for k, v in self._ranges.items()}
@property
def size(self) -> int:
return self._size
def named_array(self, array: np.ndarray) -> "NamedArray":
return NamedArray(array, self)
def new_named_array(self) -> "NamedArray":
return NamedArray(np.zeros((self.size,)), self)
def new_random_named_array(self) -> "NamedArray":
return NamedArray(np.random.uniform(size=(self.size,)), self)
class NamedArray:
"""Given a numpy array and a NamedRange, access slices by name."""
def __init__(self, array: np.ndarray, names: NamedRanges):
if array.shape != (names.size,):
raise ValueError(f"Wrong sizes: {array.shape} != ({names.size},)")
self._array = array
self._names = names
def __getitem__(
self, name: Union[str, Tuple[str, Union[int, List[int],
slice]]]) -> np.ndarray:
"""Return a read-only view into the array by name."""
if isinstance(name, str):
arr = self._array[self._names[name]]
else:
name, i = name
arr = self._array[np.array(self._names[name])[i]]
if not np.isscalar(arr):
# Read-only because it's indexed by an array of potentially non-contiguous
# indices, which isn't representable as a view, so numpy returns a copy;
# writes to that copy would not modify the underlying array as expected.
arr.flags.writeable = False
return arr
def __setitem__(
self, name: Union[str, Tuple[str, Union[int, List[int], slice]]], value):
"""Set one or more values of a range to a value."""
if isinstance(name, str):
self._array[self._names[name]] = value
else:
name, i = name
self._array[np.array(self._names[name])[i]] = value
@property
def array(self) -> np.ndarray:
return self._array
@property
def names(self) -> NamedRanges:
return self._names
def to_dict(self) -> Mapping[str, np.ndarray]:
return {k: self[k] for k in self._names.names()}
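
A short tour of the two classes, using illustrative coil names; the full behaviour is pinned down by the test file below:

```python
# Name slices of a flat vector, then read and write them by name.
import numpy as np
from fusion_tcv import named_array

ranges = named_array.NamedRanges({"E": 8, "F": 8, "OH": 2})
vec = ranges.named_array(np.arange(ranges.size, dtype=float))
vec["OH"] = [10.0, -10.0]  # Write a whole named slice.
first_f = vec["F", 0]      # Read a single element of a named slice.
as_dict = vec.to_dict()    # {"E": ..., "F": ..., "OH": ...}
```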

View File

@@ -0,0 +1,91 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for named_array."""
from absl.testing import absltest
import numpy as np
from fusion_tcv import named_array
class NamedRangesTest(absltest.TestCase):
def test_lengths_to_ranges(self):
self.assertEqual(named_array.lengths_to_ranges({"a": 2, "b": 3}),
{"a": [0, 1], "b": [2, 3, 4]})
def test_named_ranges(self):
action_counts = {"E": 8, "F": 8, "OH": 2, "DUMMY": 1, "G": 1}
actions = named_array.NamedRanges(action_counts)
self.assertEqual(actions.range("E"), list(range(8)))
self.assertEqual(actions["F"], list(range(8, 16)))
self.assertEqual(actions.range("G"), [19])
self.assertEqual(actions.index("G"), 19)
with self.assertRaises(ValueError):
actions.index("F")
for k, v in action_counts.items():
self.assertEqual(actions.count(k), v)
self.assertEqual(actions.counts(), action_counts)
self.assertEqual(list(actions.names()), list(action_counts.keys()))
self.assertEqual(actions.size, sum(action_counts.values()))
refs = actions.new_named_array()
self.assertEqual(refs.array.shape, (actions.size,))
np.testing.assert_array_equal(refs.array, np.zeros((actions.size,)))
refs = actions.new_random_named_array()
self.assertEqual(refs.array.shape, (actions.size,))
self.assertFalse(np.array_equal(refs.array, np.zeros((actions.size,))))
class NamedArrayTest(absltest.TestCase):
def test_name_array(self):
action_counts = {"E": 8, "F": 8, "OH": 2, "DUMMY": 1, "G": 1}
actions_ranges = named_array.NamedRanges(action_counts)
actions_array = np.arange(actions_ranges.size) + 100
actions = named_array.NamedArray(actions_array, actions_ranges)
for k in action_counts:
self.assertEqual(list(actions[k]), [v + 100 for v in actions_ranges[k]])
actions["G"] = -5
self.assertEqual(list(actions["G"]), [-5])
self.assertEqual(actions_array[19], -5)
for i in range(action_counts["E"]):
actions.names.set_range(f"E_{i}", [i])
actions["E_3"] = 53
self.assertEqual(list(actions["E_1"]), [101])
self.assertEqual(list(actions["E_3"]), [53])
self.assertEqual(actions_array[3], 53)
actions["F", 2] = 72
self.assertEqual(actions_array[10], 72)
actions["F", [4, 5]] = 74
self.assertEqual(actions_array[12], 74)
self.assertEqual(actions_array[13], 74)
actions["F", 0:2] = 78
self.assertEqual(actions_array[8], 78)
self.assertEqual(actions_array[9], 78)
self.assertEqual(list(actions["F"]), [78, 78, 72, 111, 74, 74, 114, 115])
with self.assertRaises(ValueError):
actions["F"][5] = 85
if __name__ == "__main__":
absltest.main()

117
fusion_tcv/noise.py Normal file
View File

@@ -0,0 +1,117 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Settings for adding noise to the action and measurements."""
import numpy as np
from numpy import random
from fusion_tcv import tcv_common
class Noise:
"""Class for adding noise to the action and measurements."""
def __init__(self,
action_mean=None,
action_std=None,
measurements_mean=None,
measurements_std=None,
seed=None):
"""Initializes the class.
Args:
action_mean: Mean of the Gaussian action noise (action bias).
action_std: Std of the Gaussian action noise.
measurements_mean: Dictionary mapping TCV measurement names to the mean of
the Gaussian noise (measurement bias).
measurements_std: Dictionary mapping TCV measurement names to the std of
the Gaussian noise.
seed: Seed for the random number generator. If None, the seed is unset.
"""
# Check all of the shapes are present and correct.
assert action_std.shape == (tcv_common.NUM_ACTIONS,)
assert action_mean.shape == (tcv_common.NUM_ACTIONS,)
for name, num in tcv_common.TCV_MEASUREMENTS.items():
assert name in measurements_std
assert measurements_mean[name].shape == (num,)
assert measurements_std[name].shape == (num,)
self._action_mean = action_mean
self._action_std = action_std
self._meas_mean = measurements_mean
self._meas_std = measurements_std
self._meas_mean_vec = tcv_common.dict_to_measurement(self._meas_mean)
self._meas_std_vec = tcv_common.dict_to_measurement(self._meas_std)
self._gen = random.RandomState(seed)
@classmethod
def use_zero_noise(cls):
no_noise_mean = dict()
no_noise_std = dict()
for name, num in tcv_common.TCV_MEASUREMENTS.items():
no_noise_mean[name] = np.zeros((num,))
no_noise_std[name] = np.zeros((num,))
return cls(
action_mean=np.zeros((tcv_common.NUM_ACTIONS)),
action_std=np.zeros((tcv_common.NUM_ACTIONS)),
measurements_mean=no_noise_mean,
measurements_std=no_noise_std)
@classmethod
def use_default_noise(cls, scale=1):
"""Returns the default observation noise parameters."""
# There is no noise added to the actions, because the noise should be added
# to the action after/as part of the power supply model as opposed to the
# input to the power supply model.
action_noise_mean = np.zeros((tcv_common.NUM_ACTIONS))
action_noise_std = np.zeros((tcv_common.NUM_ACTIONS))
meas_noise_mean = dict()
for key, l in tcv_common.TCV_MEASUREMENTS.items():
meas_noise_mean[key] = np.zeros((l,))
meas_noise_std = dict(
clint_vloop=np.array([0]),
clint_rvloop=np.array([scale * 1e-4] * 37),
bm=np.array([scale * 1e-4] * 38),
IE=np.array([scale * 20] * 8),
IF=np.array([scale * 5] * 8),
IOH=np.array([scale * 20] * 2),
Bdot=np.array([scale * 0.05] * 20),
DIOH=np.array([scale * 30]),
FIR_FRINGE=np.array([0]),
IG=np.array([scale * 2.5]),
ONEMM=np.array([0]),
vloop=np.array([scale * 0.3]),
IPHI=np.array([0]),
)
return cls(
action_mean=action_noise_mean,
action_std=action_noise_std,
measurements_mean=meas_noise_mean,
measurements_std=meas_noise_std)
def add_action_noise(self, action):
errs = self._gen.normal(size=action.shape,
loc=self._action_mean,
scale=self._action_std)
return action + errs
def add_measurement_noise(self, measurement_vec):
errs = self._gen.normal(size=measurement_vec.shape,
loc=self._meas_mean_vec,
scale=self._meas_std_vec)
# Make the IOH measurements consistent. The "real" measurements are IOH
# and DIOH, so use those.
errs = tcv_common.measurements_to_dict(errs)
errs["IOH"][1] = errs["IOH"][0] + errs["DIOH"][0]
errs = tcv_common.dict_to_measurement(errs)
return measurement_vec + errs
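
A sketch of applying the default noise model, assuming `tcv_common.TCV_MEASUREMENT_RANGES` provides the measurement vector layout (as the code above implies):

```python
# Add the default sensor noise to a random measurement vector.
from fusion_tcv import noise
from fusion_tcv import tcv_common

noiser = noise.Noise.use_default_noise(scale=1)
measurements = tcv_common.TCV_MEASUREMENT_RANGES.new_random_named_array().array
noisy = noiser.add_measurement_noise(measurements)
assert noisy.shape == measurements.shape
```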

View File

@@ -0,0 +1,126 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for varying parameters from simulation to simulation."""
from typing import Dict, Optional, Tuple
import dataclasses
import numpy as np
from fusion_tcv import tcv_common
# Pylint does not like variable names like `qA`.
# pylint: disable=invalid-name
RP_DEFAULT = 5e-6
LP_DEFAULT = 2.05e-6
BP_DEFAULT = 0.25
QA_DEFAULT = 1.3
@dataclasses.dataclass
class Settings:
"""Settings to modify solver/plasma model."""
# Inverse of the resistivity.
# Plasma circuit equation is roughly
# k * dIoh/dt = L * dIp/dt + R * I = Vloop
# where R is roughly (1 / signeo) or rp.
# Value is multiplier on the default value.
# This parameter does not apply to the OhmTor diffusion.
signeo: Tuple[float, float] = (1, 1)
# Rp Plasma resistivity. The value is an absolute value.
rp: float = RP_DEFAULT
# Plasma self-inductance. The value is an absolute value.
lp: float = LP_DEFAULT
# Proportional to the plasma pressure. The value is an absolute value.
bp: float = BP_DEFAULT
# Plasma current profile. Value is absolute.
qA: float = QA_DEFAULT
# Initial OH coil current. Applied to both coils.
ioh: Optional[float] = None
# The voltage offsets for the various coils.
psu_voltage_offset: Optional[Dict[str, float]] = None
def _psu_voltage_offset_string(self) -> str:
"""Return a short-ish, readable string of the psu voltage offsets."""
if not self.psu_voltage_offset:
return "None"
if len(self.psu_voltage_offset) < 8: # Only a few, output individually.
return ", ".join(
f"{coil.replace('_00', '')}: {offset:.0f}"
for coil, offset in self.psu_voltage_offset.items())
# Otherwise, too long, so output in groups.
groups = []
for coil, action_range in tcv_common.TCV_ACTION_RANGES.ranges():
offsets = [self.psu_voltage_offset.get(tcv_common.TCV_ACTIONS[i], 0)
for i in action_range]
if any(offsets):
groups.append(f"{coil}: " + ",".join(f"{offset:.0f}"
for offset in offsets))
return ", ".join(groups)
class ParamGenerator:
"""Varies parameters using uniform/loguniform distributions.
Absolute parameters are varied using uniform distributions while scaling
parameters use a loguniform distribution.
"""
def __init__(self,
rp_bounds: Optional[Tuple[float, float]] = None,
lp_bounds: Optional[Tuple[float, float]] = None,
qA_bounds: Optional[Tuple[float, float]] = None,
bp_bounds: Optional[Tuple[float, float]] = None,
rp_mean: float = RP_DEFAULT,
lp_mean: float = LP_DEFAULT,
bp_mean: float = BP_DEFAULT,
qA_mean: float = QA_DEFAULT,
ioh_bounds: Optional[Tuple[float, float]] = None,
psu_voltage_offset_bounds: Optional[
Dict[str, Tuple[float, float]]] = None):
# Do not allow Signeo variation as this does not work with OhmTor current
# diffusion.
no_scaling = (1, 1)
self._rp_bounds = rp_bounds if rp_bounds else no_scaling
self._lp_bounds = lp_bounds if lp_bounds else no_scaling
self._bp_bounds = bp_bounds if bp_bounds else no_scaling
self._qA_bounds = qA_bounds if qA_bounds else no_scaling
self._rp_mean = rp_mean
self._lp_mean = lp_mean
self._bp_mean = bp_mean
self._qA_mean = qA_mean
self._ioh_bounds = ioh_bounds
self._psu_voltage_offset_bounds = psu_voltage_offset_bounds
def generate(self) -> Settings:
return Settings(
signeo=(1, 1),
rp=loguniform_rv(*self._rp_bounds) * self._rp_mean,
lp=loguniform_rv(*self._lp_bounds) * self._lp_mean,
bp=loguniform_rv(*self._bp_bounds) * self._bp_mean,
qA=loguniform_rv(*self._qA_bounds) * self._qA_mean,
ioh=np.random.uniform(*self._ioh_bounds) if self._ioh_bounds else None,
psu_voltage_offset=(
{coil: np.random.uniform(*bounds)
for coil, bounds in self._psu_voltage_offset_bounds.items()}
if self._psu_voltage_offset_bounds else None))
def loguniform_rv(lower, upper):
"""Generate loguniform random variable between min and max."""
if lower == upper:
return lower
assert lower < upper
return np.exp(np.random.uniform(np.log(lower), np.log(upper)))
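
A minimal sketch of domain randomization with this generator; the bounds are illustrative, not the values used for training:

```python
# Draw fresh physics parameters per simulated episode.
from fusion_tcv import param_variation

gen = param_variation.ParamGenerator(
    rp_bounds=(0.5, 2.0),      # Scale resistivity by up to 2x either way.
    lp_bounds=(0.9, 1.1),      # Small variation of the self-inductance.
    ioh_bounds=(-1000, 1000))  # Randomize the initial OH coil current.
settings = gen.generate()
print(settings.rp, settings.lp, settings.ioh)
```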

202
fusion_tcv/ref_gen.py Normal file
View File

@@ -0,0 +1,202 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generators for References vector."""
import abc
import copy
from typing import List, Optional
import dataclasses
from fusion_tcv import named_array
from fusion_tcv import shape as shape_lib
from fusion_tcv import tcv_common
class AbstractReferenceGenerator(abc.ABC):
"""Abstract class for generating the reference signal."""
@abc.abstractmethod
def reset(self) -> named_array.NamedArray:
"""Resets the class for a new episode and returns the first reference."""
@abc.abstractmethod
def step(self) -> named_array.NamedArray:
"""Returns the reference signal."""
@dataclasses.dataclass
class LinearTransition:
reference: named_array.NamedArray # Reference at which to end the transition.
transition_steps: int # Number of intermediate steps between the shapes.
steady_steps: int # Number of steps in the steady state.
class LinearTransitionReferenceGenerator(AbstractReferenceGenerator):
"""A base class for generating references that are a series of transitions."""
def __init__(self, start_offset: int = 0):
self._last_ref = None
self._reset_counters()
self._start_offset = start_offset
@abc.abstractmethod
def _next_transition(self) -> LinearTransition:
"""Override this in the subclass."""
def reset(self) -> named_array.NamedArray:
self._last_ref = None
self._reset_counters()
for _ in range(self._start_offset):
self.step()
return self.step()
def _reset_counters(self):
self._steady_step = 0
self._transition_step = 0
self._transition = None
def step(self) -> named_array.NamedArray:
if (self._transition is None or
self._steady_step == self._transition.steady_steps):
if self._transition is not None:
self._last_ref = self._transition.reference
self._reset_counters()
self._transition = self._next_transition()
# Ensure at least one steady step in middle transitions. Relaxing this
# would require changing the logic below, which assumes there is at least
# one step in the steady phase.
assert self._transition.steady_steps > 0
assert self._transition is not None # to make pytype happy
transition_steps = self._transition.transition_steps
if self._last_ref is None: # No transition at beginning of episode.
transition_steps = 0
if self._transition_step < transition_steps: # In transition phase.
self._transition_step += 1
a = self._transition_step / (self._transition.transition_steps + 1) # pytype: disable=attribute-error
return self._last_ref.names.named_array(
self._last_ref.array * (1 - a) + self._transition.reference.array * a) # pytype: disable=attribute-error
else: # In steady phase.
self._steady_step += 1
return copy.deepcopy(self._transition.reference)
class FixedReferenceGenerator(LinearTransitionReferenceGenerator):
"""Generates linear transitions from a fixed set of references."""
def __init__(self, transitions: List[LinearTransition],
start_offset: int = 0):
self._transitions = transitions
self._current_transition = 0
super().__init__(start_offset=start_offset)
def reset(self) -> named_array.NamedArray:
self._current_transition = 0
return super().reset()
def _next_transition(self) -> LinearTransition:
if self._current_transition == len(self._transitions):
# Have gone through all of the transitions. Return the final reference
# for a very long time.
return LinearTransition(steady_steps=50000, transition_steps=0,
reference=self._transitions[-1].reference)
self._current_transition += 1
return copy.deepcopy(self._transitions[self._current_transition - 1])
@dataclasses.dataclass
class TimedTransition:
steady_steps: int # Number of steps to hold the shape.
transition_steps: int # Number of steps to transition.
@dataclasses.dataclass
class ParametrizedShapeTimedTarget:
"""RZIP condition with a timestep attached."""
shape: shape_lib.Shape
timing: TimedTransition
class PresetShapePointsReferenceGenerator(FixedReferenceGenerator):
"""Generates a fixed set of shape points."""
def __init__(
self, targets: List[ParametrizedShapeTimedTarget], start_offset: int = 0):
if targets[0].timing.transition_steps != 0:
raise ValueError("Invalid first timing, transition must be 0, not "
f"{targets[0].timing.transition_steps}")
transitions = []
for target in targets:
transitions.append(LinearTransition(
steady_steps=target.timing.steady_steps,
transition_steps=target.timing.transition_steps,
reference=target.shape.canonical().gen_references()))
super().__init__(transitions, start_offset=start_offset)
class ShapeFromShot(PresetShapePointsReferenceGenerator):
"""Generate shapes from EPFL references."""
def __init__(
self, time_slices: List[shape_lib.ReferenceTimeSlice],
start: Optional[float] = None):
"""Given a series of time slices, start from time_slice.time==start."""
if start is None:
start = time_slices[0].time
dt = 1e-4
targets = []
time_slices = shape_lib.canonicalize_reference_series(time_slices)
prev = None
for i, ref in enumerate(time_slices):
assert prev is None or prev.hold < ref.time
if ref.time < start:
continue
if prev is None and start != ref.time:
raise ValueError("start must be one of the time slice times.")
steady = (max(1, int((ref.hold - ref.time) / dt))
if i < len(time_slices) - 1 else 100000)
transition = (0 if prev is None else
(int((ref.time - prev.time) / dt) -
max(1, int((prev.hold - prev.time) / dt))))
targets.append(ParametrizedShapeTimedTarget(
shape=ref.shape,
timing=TimedTransition(
steady_steps=steady, transition_steps=transition)))
prev = ref
assert targets
super().__init__(targets)
@dataclasses.dataclass
class RZIpTarget:
r: float
z: float
ip: float
def make_symmetric_multidomain_rzip_reference(
target: RZIpTarget) -> named_array.NamedArray:
"""Generate multi-domain rzip references."""
refs = tcv_common.REF_RANGES.new_named_array()
refs["R"] = (target.r, target.r)
refs["Z"] = (target.z, -target.z)
refs["Ip"] = (target.ip, target.ip)
return refs
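
Putting the pieces together, a two-phase reference schedule might look like the sketch below; the RZIp targets are illustrative, not taken from a real shot:

```python
# Hold one symmetric two-domain RZIp target, then ramp to another.
from fusion_tcv import ref_gen

refs_a = ref_gen.make_symmetric_multidomain_rzip_reference(
    ref_gen.RZIpTarget(r=0.88, z=0.3, ip=-55000))
refs_b = ref_gen.make_symmetric_multidomain_rzip_reference(
    ref_gen.RZIpTarget(r=0.88, z=0.2, ip=-60000))
generator = ref_gen.FixedReferenceGenerator([
    ref_gen.LinearTransition(
        reference=refs_a, transition_steps=0, steady_steps=100),
    ref_gen.LinearTransition(
        reference=refs_b, transition_steps=50, steady_steps=100),
])
ref = generator.reset()
for _ in range(200):
  ref = generator.step()  # Interpolates linearly during the transition.
```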

1589
fusion_tcv/references.py Normal file

File diff suppressed because it is too large

View File

@@ -0,0 +1,58 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Print the values of references."""
from typing import Sequence
from absl import app
from absl import flags
from fusion_tcv import named_array
from fusion_tcv import references
from fusion_tcv import tcv_common
_refs = flags.DEFINE_enum("refs", None, references.REFERENCES.keys(),
"Which references to print")
_count = flags.DEFINE_integer("count", 100, "How many timesteps to print.")
_freq = flags.DEFINE_integer("freq", 1, "Print only every so often.")
_fields = flags.DEFINE_multi_enum(
"field", None, tcv_common.REF_RANGES.names(),
"Which reference fields to print, default of all.")
flags.mark_flag_as_required("refs")
def print_ref(step: int, ref: named_array.NamedArray):
print(f"Step: {step}")
for k in (_fields.value or ref.names.names()):
print(f" {k}: [{', '.join(f'{v:.3f}' for v in ref[k])}]")
def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
if _freq.value <= 0:
raise app.UsageError("`freq` must be >0.")
ref = references.REFERENCES[_refs.value]()
print_ref(0, ref.reset())
for i in range(1, _count.value + 1):
for _ in range(_freq.value - 1):
ref.step()
print_ref(i * _freq.value, ref.step())
print(f"Stopped after {_count.value * _freq.value} steps.")
if __name__ == "__main__":
app.run(main)

View File

@@ -0,0 +1,4 @@
absl-py>=0.13.0
dm-env>=1.5
numpy>=1.21.0
scipy>=1.7.0

185
fusion_tcv/rewards.py Normal file
View File

@@ -0,0 +1,185 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Reward function for the fusion environment."""
import abc
import collections
import functools
from typing import Callable, Dict, List, Optional, Text, Tuple, Union
from absl import logging
import dataclasses
import numpy as np
from fusion_tcv import combiners
from fusion_tcv import fge_state
from fusion_tcv import named_array
from fusion_tcv import targets as targets_lib
from fusion_tcv import transforms
class AbstractMeasure(abc.ABC):
@abc.abstractmethod
def __call__(self, targets: List[targets_lib.Target]) -> List[float]:
"""Returns a list of error measures."""
class AbsDist(AbstractMeasure):
"""Return the absolute distance between the actual and target."""
@staticmethod
def __call__(targets: List[targets_lib.Target]) -> List[float]:
return [abs(t.actual - t.target) for t in targets]
@dataclasses.dataclass(frozen=True)
class MeasureDetails:
min: float
mean: float
max: float
@dataclasses.dataclass
class RewardDetails:
reward: float # 0-1 reward value.
weighted: float # Weighted reward contribution; these sum to a value in [0, 1].
weight: float
measure: Optional[MeasureDetails] = None
class AbstractReward(abc.ABC):
"""Abstract reward class."""
@abc.abstractmethod
def reward(
self,
voltages: np.ndarray,
state: fge_state.FGEState,
references: named_array.NamedArray,
) -> Tuple[float, Dict[Text, List[RewardDetails]]]:
"""Returns the reward and log dict as a function of the penalty term."""
@abc.abstractmethod
def terminal_reward(self) -> float:
"""Returns the reward if the simulator crashed."""
WeightFn = Callable[[named_array.NamedArray], float]
WeightOrFn = Union[float, WeightFn]
@dataclasses.dataclass
class Component:
target: targets_lib.AbstractTarget
transforms: List[transforms.AbstractTransform]
measure: AbstractMeasure = dataclasses.field(default_factory=AbsDist)
weight: Union[WeightOrFn, List[WeightOrFn]] = 1
name: Optional[str] = None
class Reward(AbstractReward):
"""Combines a bunch of reward components into a single reward.
The component parts are applied in the order: target, measure, transform.
- Targets represent some error value as one or more pairs of values
(target, actual), usually with some meaningful physical unit (e.g. distance,
volts, etc).
- Measures combine each (target, actual) pair into a single float, for
example the absolute distance.
- Transforms can make arbitrary conversions, but one of them must map from
that (often physically meaningful) scale to a reward in the 0-1 range.
- Combiners are a special type of transform that reduces a vector of values
down to a single value. The combiner can be skipped if the target only
outputs a single value, or if you want a vector of outputs for the final
combiner.
- The component weights are passed to the final combiner, and must match the
number of outputs for that component.
"""
def __init__(self,
components: List[Component],
combiner: combiners.AbstractCombiner,
terminal_reward: float = -5,
reward_scale: float = 0.01):
self._components = components
self._combiner = combiner
self._terminal_reward = terminal_reward
self._reward_scale = reward_scale
self._weights = []
component_count = collections.Counter()
for component in self._components:
num_outputs = component.target.outputs
for transform in component.transforms:
if transform.outputs is not None:
num_outputs = transform.outputs
if not isinstance(component.weight, list):
component.weight = [component.weight]
if len(component.weight) != num_outputs:
name = component.name or component.target.name
raise ValueError(f"Wrong number of weights for '{name}': got:"
f" {len(component.weight)}, expected: {num_outputs}")
self._weights.extend(component.weight)
def terminal_reward(self) -> float:
return self._terminal_reward * self._reward_scale
def reward(
self,
voltages: np.ndarray,
state: fge_state.FGEState,
references: named_array.NamedArray,
) -> Tuple[float, Dict[Text, List[RewardDetails]]]:
values = []
weights = [weight(references) if callable(weight) else weight
for weight in self._weights]
reward_dict = collections.defaultdict(list)
for component in self._components:
name = component.name or component.target.name
num_outputs = len(component.weight)
component_weights = weights[len(values):(len(values) + num_outputs)]
try:
target = component.target(voltages, state, references)
except targets_lib.TargetError:
logging.exception("Target failed.")
# Failed turns into minimum reward.
measure = [987654321] * num_outputs
transformed = [0] * num_outputs
else:
measure = component.measure(target)
transformed = functools.reduce(
(lambda e, fn: fn(e)), component.transforms, measure)
assert len(transformed) == num_outputs
for v in transformed:
if not np.isnan(v) and not 0 <= v <= 1:
raise ValueError(f"The transformed value in {name} is invalid: {v}")
values.extend(transformed)
for weight, value in zip(component_weights, transformed):
measure = [m for m in measure if not np.isnan(m)] or [float("nan")]
reward_dict[name].append(RewardDetails(
value, weight * value * self._reward_scale,
weight if not np.isnan(value) else 0,
MeasureDetails(
min(measure), sum(measure) / len(measure), max(measure))))
sum_weights = sum(sum(d.weight for d in detail)
for detail in reward_dict.values())
for reward_details in reward_dict.values():
for detail in reward_details:
detail.weighted /= sum_weights
final_combined = self._combiner(values, weights)
assert len(final_combined) == 1
return final_combined[0] * self._reward_scale, reward_dict
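
To make the pipeline concrete, here is a minimal single-component reward assembled from pieces that appear in rewards_used.py below (a sketch; it assumes the `targets.Ip` target and the stub FGE state):

```python
# Reward plasma current within ~500A of its reference, softly decaying
# towards 0 by an error of 20kA.
import numpy as np
from fusion_tcv import combiners
from fusion_tcv import fge_state
from fusion_tcv import rewards
from fusion_tcv import targets
from fusion_tcv import tcv_common
from fusion_tcv import transforms

reward = rewards.Reward(
    [rewards.Component(
        target=targets.Ip(),
        transforms=[transforms.SoftPlus(good=500, bad=20000)])],
    combiner=combiners.SmoothMax(-0.5))
value, details = reward.reward(
    voltages=np.zeros(tcv_common.NUM_ACTIONS),
    state=fge_state.FGEState(num_plasmas=1),
    references=tcv_common.REF_RANGES.new_named_array())
```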

253
fusion_tcv/rewards_used.py Normal file
View File

@@ -0,0 +1,253 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The rewards used in our experiments."""
from fusion_tcv import combiners
from fusion_tcv import rewards
from fusion_tcv import targets
from fusion_tcv import transforms
# Used in TCV#70915
FUNDAMENTAL_CAPABILITY = rewards.Reward([
rewards.Component(
target=targets.ShapeLCFSDistance(),
transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
combiners.SmoothMax(-1)]),
rewards.Component(
target=targets.XPointFar(),
transforms=[transforms.Sigmoid(good=0.3, bad=0.1),
combiners.SmoothMax(-5)]),
rewards.Component(
target=targets.LimitPoint(),
transforms=[transforms.Sigmoid(bad=0.2, good=0.1)]),
rewards.Component(
target=targets.XPointNormalizedFlux(num_points=1),
transforms=[transforms.SoftPlus(bad=0.08)]),
rewards.Component(
target=targets.XPointDistance(num_points=1),
transforms=[transforms.Sigmoid(good=0.01, bad=0.15)]),
rewards.Component(
target=targets.XPointFluxGradient(num_points=1),
transforms=[transforms.SoftPlus(bad=3)],
weight=0.5),
rewards.Component(
target=targets.Ip(),
transforms=[transforms.SoftPlus(good=500, bad=20000)]),
rewards.Component(
target=targets.OHCurrentsClose(),
transforms=[transforms.SoftPlus(good=50, bad=1050)]),
], combiners.SmoothMax(-0.5))
# Used in TCV#70920
ELONGATION = rewards.Reward([
rewards.Component(
target=targets.ShapeLCFSDistance(),
transforms=[transforms.SoftPlus(good=0.003, bad=0.03),
combiners.SmoothMax(-1)],
weight=3),
rewards.Component(
target=targets.ShapeRadius(),
transforms=[transforms.SoftPlus(good=0.002, bad=0.02)]),
rewards.Component(
target=targets.ShapeElongation(),
transforms=[transforms.SoftPlus(good=0.005, bad=0.2)]),
rewards.Component(
target=targets.ShapeTriangularity(),
transforms=[transforms.SoftPlus(good=0.005, bad=0.2)]),
rewards.Component(
target=targets.XPointCount(),
transforms=[transforms.Equal()]),
rewards.Component(
target=targets.LimitPoint(), # Stay away from the top/baffles.
transforms=[transforms.Sigmoid(bad=0.3, good=0.2)]),
rewards.Component(
target=targets.Ip(),
transforms=[transforms.SoftPlus(good=500, bad=30000)]),
rewards.Component(
target=targets.VoltageOOB(),
transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
rewards.Component(
target=targets.OHCurrentsClose(),
transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
rewards.Component(
name="CurrentsFarFromZero",
target=targets.EFCurrents(),
transforms=[transforms.Abs(),
transforms.SoftPlus(good=100, bad=50),
combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-5))
# Used in TCV#70600
ITER = rewards.Reward([
rewards.Component(
target=targets.ShapeLCFSDistance(),
transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
combiners.SmoothMax(-1)],
weight=3),
rewards.Component(
target=targets.Diverted(),
transforms=[transforms.Equal()]),
rewards.Component(
target=targets.XPointNormalizedFlux(num_points=2),
transforms=[transforms.SoftPlus(bad=0.08)],
weight=[1] * 2),
rewards.Component(
target=targets.XPointDistance(num_points=2),
transforms=[transforms.Sigmoid(good=0.01, bad=0.15)],
weight=[0.5] * 2),
rewards.Component(
target=targets.XPointFluxGradient(num_points=2),
transforms=[transforms.SoftPlus(bad=3)],
weight=[0.5] * 2),
rewards.Component(
target=targets.LegsNormalizedFlux(),
transforms=[transforms.Sigmoid(good=0.1, bad=0.3),
combiners.SmoothMax(-5)],
weight=2),
rewards.Component(
target=targets.Ip(),
transforms=[transforms.SoftPlus(good=500, bad=20000)],
weight=2),
rewards.Component(
target=targets.VoltageOOB(),
transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
rewards.Component(
target=targets.OHCurrentsClose(),
transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
rewards.Component(
name="CurrentsFarFromZero",
target=targets.EFCurrents(),
transforms=[transforms.Abs(),
transforms.SoftPlus(good=100, bad=50),
combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-5))
# Used in TCV#70755
SNOWFLAKE = rewards.Reward([
rewards.Component(
target=targets.ShapeLCFSDistance(),
transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
combiners.SmoothMax(-1)],
weight=3),
rewards.Component(
target=targets.LimitPoint(),
transforms=[transforms.Sigmoid(bad=0.2, good=0.1)]),
rewards.Component(
target=targets.XPointNormalizedFlux(num_points=2),
transforms=[transforms.SoftPlus(bad=0.08)],
weight=[1] * 2),
rewards.Component(
target=targets.XPointDistance(num_points=2),
transforms=[transforms.Sigmoid(good=0.01, bad=0.15)],
weight=[0.5] * 2),
rewards.Component(
target=targets.XPointFluxGradient(num_points=2),
transforms=[transforms.SoftPlus(bad=3)],
weight=[0.5] * 2),
rewards.Component(
target=targets.LegsNormalizedFlux(),
transforms=[transforms.Sigmoid(good=0.1, bad=0.3),
combiners.SmoothMax(-5)],
weight=2),
rewards.Component(
target=targets.Ip(),
transforms=[transforms.SoftPlus(good=500, bad=20000)],
weight=2),
rewards.Component(
target=targets.VoltageOOB(),
transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
rewards.Component(
target=targets.OHCurrentsClose(),
transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
rewards.Component(
name="CurrentsFarFromZero",
target=targets.EFCurrents(),
transforms=[transforms.Abs(),
transforms.SoftPlus(good=100, bad=50),
combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-5))
# Used in TCV#70457
NEGATIVE_TRIANGULARITY = rewards.Reward([
rewards.Component(
target=targets.ShapeLCFSDistance(),
transforms=[transforms.SoftPlus(good=0.005, bad=0.05),
combiners.SmoothMax(-1)],
weight=3),
rewards.Component(
target=targets.ShapeRadius(),
transforms=[transforms.SoftPlus(bad=0.04)]),
rewards.Component(
target=targets.ShapeElongation(),
transforms=[transforms.SoftPlus(bad=0.5)]),
rewards.Component(
target=targets.ShapeTriangularity(),
transforms=[transforms.SoftPlus(bad=0.5)]),
rewards.Component(
target=targets.Diverted(),
transforms=[transforms.Equal()]),
rewards.Component(
target=targets.XPointNormalizedFlux(num_points=2),
transforms=[transforms.SoftPlus(bad=0.08)],
weight=[1] * 2),
rewards.Component(
target=targets.XPointDistance(num_points=2),
transforms=[transforms.Sigmoid(good=0.02, bad=0.15)],
weight=[0.5] * 2),
rewards.Component(
target=targets.XPointFluxGradient(num_points=2),
transforms=[transforms.SoftPlus(bad=3)],
weight=[0.5] * 2),
rewards.Component(
target=targets.Ip(),
transforms=[transforms.SoftPlus(good=500, bad=20000)],
weight=2),
rewards.Component(
target=targets.VoltageOOB(),
transforms=[combiners.Mean(), transforms.SoftPlus(bad=1)]),
rewards.Component(
target=targets.OHCurrentsClose(),
transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
rewards.Component(
name="CurrentsFarFromZero",
target=targets.EFCurrents(),
transforms=[transforms.Abs(),
transforms.SoftPlus(good=100, bad=50),
combiners.GeometricMean()]),
], combiner=combiners.SmoothMax(-0.5))
# Used in TCV#69545
DROPLETS = rewards.Reward([
rewards.Component(
target=targets.R(indices=[0, 1]),
transforms=[transforms.Sigmoid(good=0.02, bad=0.5)],
weight=[1, 1]),
rewards.Component(
target=targets.Z(indices=[0, 1]),
transforms=[transforms.Sigmoid(good=0.02, bad=0.2)],
weight=[1, 1]),
rewards.Component(
target=targets.Ip(indices=[0, 1]),
transforms=[transforms.Sigmoid(good=2000, bad=20000)],
weight=[1, 1]),
rewards.Component(
target=targets.OHCurrentsClose(),
transforms=[transforms.ClippedLinear(good=50, bad=1050)]),
], combiner=combiners.GeometricMean())

40
fusion_tcv/run_loop.py Normal file
View File

@@ -0,0 +1,40 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run an agent on the environment."""
import numpy as np
from fusion_tcv import environment
from fusion_tcv import trajectory
def run_loop(env: environment.Environment, agent,
max_steps: int = 100000) -> trajectory.Trajectory:
"""Run an agent."""
results = []
agent.reset()
ts = env.reset()
for _ in range(max_steps):
obs = ts.observation
action = agent.step(ts)
ts = env.step(action)
results.append(trajectory.Trajectory(
measurements=obs["measurements"],
references=obs["references"],
actions=action,
reward=np.array(ts.reward)))
if ts.last():
break
return trajectory.Trajectory.stack(results)
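
The agent interface this loop expects is minimal: `reset()` plus `step(timestep)` returning an action vector. A constant-voltage sketch (presumably `agent.ZeroAgent` is the all-zeros special case):

```python
# An open-loop agent that ignores observations entirely.
import numpy as np
from fusion_tcv import tcv_common

class ConstantAgent:
  """Always outputs the same coil voltages."""

  def __init__(self, action: np.ndarray):
    self._action = action

  def reset(self):
    pass

  def step(self, timestep) -> np.ndarray:
    del timestep  # Observations are ignored.
    return self._action

constant_agent = ConstantAgent(np.zeros(tcv_common.NUM_ACTIONS))
```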

468
fusion_tcv/shape.py Normal file
View File

@@ -0,0 +1,468 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for time varying shape control."""
import copy
import enum
import random
from typing import List, Optional, NamedTuple, Tuple, Union
import dataclasses
import numpy as np
from scipy import interpolate
from fusion_tcv import named_array
from fusion_tcv import tcv_common
class Point(NamedTuple):
"""A point in r,z coordinates."""
r: float
z: float
def to_polar(self) -> "PolarPoint":
return PolarPoint(np.arctan2(self.z, self.r),
np.sqrt(self.r**2 + self.z**2))
def __neg__(self):
return Point(-self.r, -self.z)
def __add__(self, pt_or_val: Union["Point", float]):
if isinstance(pt_or_val, Point):
return Point(self.r + pt_or_val.r, self.z + pt_or_val.z)
else:
return Point(self.r + pt_or_val, self.z + pt_or_val)
def __sub__(self, pt_or_val: Union["Point", float]):
if isinstance(pt_or_val, Point):
return Point(self.r - pt_or_val.r, self.z - pt_or_val.z)
else:
return Point(self.r - pt_or_val, self.z - pt_or_val)
def __mul__(self, pt_or_val: Union["Point", float]):
if isinstance(pt_or_val, Point):
return Point(self.r * pt_or_val.r, self.z * pt_or_val.z)
else:
return Point(self.r * pt_or_val, self.z * pt_or_val)
def __truediv__(self, pt_or_val: Union["Point", float]):
if isinstance(pt_or_val, Point):
return Point(self.r / pt_or_val.r, self.z / pt_or_val.z)
else:
return Point(self.r / pt_or_val, self.z / pt_or_val)
__div__ = __truediv__
def dist(p1: Union[Point, np.ndarray], p2: Union[Point, np.ndarray]) -> float:
return np.hypot(*(p1 - p2))
ShapePoints = List[Point]
def to_shape_points(array: np.ndarray) -> ShapePoints:
return [Point(r, z) for r, z in array]
def center_point(points: ShapePoints) -> Point:
return sum(points, Point(0, 0)) / len(points)
class ShapeSide(enum.Enum):
LEFT = 0
RIGHT = 1
NOSHIFT = 2
class PolarPoint(NamedTuple):
angle: float
dist: float
def to_point(self) -> Point:
return Point(np.cos(self.angle) * self.dist, np.sin(self.angle) * self.dist)
def evenly_spaced_angles(num: int):
return np.arange(num) * 2 * np.pi / num
def angle_aligned_dists(points: np.ndarray, angles: np.ndarray) -> np.ndarray:
"""Return a new set of points along angles that intersect with the shape."""
# TODO(tewalds): Walk the two arrays together for an O(n+m) algorithm instead
# of the current O(n*m). This would work as long as they are both sorted
# around the radial direction, so the next intersection will be near the last.
return np.array([dist_angle_to_surface(points, a) for a in angles])
def angle_aligned_points(points: np.ndarray, num_points: int,
origin: Point) -> np.ndarray:
"""Given a set of points, return a new space centered at origin."""
angles = evenly_spaced_angles(num_points)
dists = angle_aligned_dists(points - origin, angles)
return np.stack((np.cos(angles) * dists,
np.sin(angles) * dists), axis=-1) + origin
def dist_angle_to_surface(points: np.ndarray, angle: float) -> float:
"""Distance along a ray to the surface defined by a list of points."""
for p1, p2 in zip(points, np.roll(points, 1, axis=0)):
d = dist_angle_to_segment(p1, p2, angle)
if d is not None:
return d
raise ValueError(f"Intersecting edge not found for angle: {angle}")
def dist_angle_to_segment(p1, p2, angle: float) -> Optional[float]:
"""Distance along a ray from the origin to a segment defined by two points."""
x0, y0 = p1[0], p1[1]
x1, y1 = p2[0], p2[1]
a0, b0 = np.cos(angle), np.sin(angle)
a1, b1 = 0, 0
# Segment/segment algorithm inspired by https://stackoverflow.com/q/563198
denom = (b0 - b1) * (x0 - x1) - (y0 - y1) * (a0 - a1)
if denom == 0:
return None # Angle parallel to the segment, so can't intersect.
xy = (a0 * (y1 - b1) + a1 * (b0 - y1) + x1 * (b1 - b0)) / denom
eps = 0.00001 # Allow intersecting slightly beyond the endpoints.
if -eps <= xy <= 1 + eps: # Check it hit the segment, not just the line.
ab = (y1 * (x0 - a1) + b1 * (x1 - x0) + y0 * (a1 - x1)) / denom
if ab > 0: # Otherwise it hit in the reverse direction.
# If ab <= 1 then it's within the segment defined above, but given it's
# a unit vector with one end at the origin this tells us the distance to
# the intersection of an infinite ray out from the origin.
return ab
return None
def dist_point_to_surface(points: np.ndarray, point: np.ndarray) -> float:
"""Distance from a point to the surface defined by a list of points."""
return min(dist_point_to_segment(p1, p2, point)
for p1, p2 in zip(points, np.roll(points, 1, axis=0)))
def dist_point_to_segment(v: np.ndarray, w: np.ndarray, p: np.ndarray) -> float:
"""Return minimum distance between line segment vw and point p."""
# Inspired by: https://stackoverflow.com/a/1501725
l2 = dist(v, w)**2
if l2 == 0.0:
return dist(p, v) # v == w case
# Consider the line extending the segment, parameterized as v + t (w - v).
# We find projection of point p onto the line.
# It falls where t = [(p-v) . (w-v)] / |w-v|^2
# We clamp t from [0,1] to handle points outside the segment vw.
t = max(0, min(1, np.dot(p - v, w - v) / l2))
projection = v + t * (w - v) # Projection falls on the segment
return dist(p, projection)
def sort_by_angle(points: ShapePoints) -> ShapePoints:
center = center_point(points)
return sorted(points, key=lambda p: (p - center).to_polar().angle)
def spline_interpolate_points(
points: ShapePoints, num_points: int,
x_points: Optional[ShapePoints] = None) -> ShapePoints:
"""Interpolate along a spline to give a smooth evenly spaced shape."""
ends = []
if x_points:
# Find the shape points that must allow sharp corners.
for xp in x_points:
for i, p in enumerate(points):
if np.hypot(*(p - xp)) < 0.01:
ends.append(i)
if not ends:
# No x-points forcing sharp corners, so use a periodic spline.
tck, _ = interpolate.splprep(np.array(points + [points[0]]).T, s=0, per=1)
unew = np.arange(num_points) / num_points
out = interpolate.splev(unew, tck)
assert len(out[0]) == num_points
return sort_by_angle(to_shape_points(np.array(out).T))
# Generate spline segments between x-points, with a sharp corner at each
# x-point.
new_pts = []
for i, j in zip(ends, ends[1:] + [ends[0]]):
pts = points[i:j+1] if i < j else points[i:] + points[:j+1]
num_segment_points = np.round((len(pts) - 1) / len(points) * num_points)
unew = np.arange(num_segment_points + 1) / num_segment_points
tck, _ = interpolate.splprep(np.array(pts).T, s=0)
out = interpolate.splev(unew, tck)
new_pts += to_shape_points(np.array(out).T)[:-1]
if len(new_pts) != num_points:
raise AssertionError(
f"Generated the wrong number of points: {len(new_pts)} != {num_points}")
return sort_by_angle(new_pts)
@dataclasses.dataclass
class ParametrizedShape:
"""Describes a target shape from the parameter set."""
r0: float # Where to put the center along the radial axis.
z0: float # Where to put the center along the vertical axis.
kappa: float # Elongation of the shape. (0.8, 3)
delta: float # Triangularity of the shape. (-1, 1)
radius: float # Radius of the shape. (0.22, 2.58)
lambda_: float # Squareness of the shape. Recommended (0, 0).
side: ShapeSide # Whether and which side to shift the shape to.
@classmethod
def uniform_random_shape(
cls,
r_bounds=(0.8, 0.9),
z_bounds=(0, 0.2),
kappa_bounds=(1.0, 1.8), # elongation
delta_bounds=(-0.5, 0.6), # triangularity
radius_bounds=(tcv_common.LIMITER_WIDTH / 2 - 0.04,
tcv_common.LIMITER_WIDTH / 2),
lambda_bounds=(0, 0), # squareness
side=(ShapeSide.LEFT, ShapeSide.RIGHT)):
"""Return a random shape."""
return cls(
r0=np.random.uniform(*r_bounds),
z0=np.random.uniform(*z_bounds),
kappa=np.random.uniform(*kappa_bounds),
delta=np.random.uniform(*delta_bounds),
radius=np.random.uniform(*radius_bounds),
lambda_=np.random.uniform(*lambda_bounds),
side=side if isinstance(side, ShapeSide) else random.choice(side))
def gen_points(self, num_points: int) -> Tuple[ShapePoints, Point]:
"""Generates a set of shape points, return (points, modified (r0, z0))."""
r0 = self.r0
z0 = self.z0
num_warped_points = 32
points = np.zeros((num_warped_points, 2))
theta = evenly_spaced_angles(num_warped_points)
points[:, 0] = r0 + self.radius * np.cos(theta + self.delta * np.sin(theta)
- self.lambda_ * np.sin(2 * theta))
points[:, 1] = z0 + self.radius * self.kappa * np.sin(theta)
if self.side == ShapeSide.LEFT:
wall_shift = np.min(points[:, 0]) - tcv_common.INNER_LIMITER_R
points[:, 0] -= wall_shift
r0 -= wall_shift
elif self.side == ShapeSide.RIGHT:
wall_shift = np.max(points[:, 0]) - tcv_common.OUTER_LIMITER_R
points[:, 0] -= wall_shift
r0 -= wall_shift
return (spline_interpolate_points(to_shape_points(points), num_points),
Point(r0, z0))
def trim_zero_points(points: ShapePoints) -> Optional[ShapePoints]:
trimmed = [p for p in points if p.r != 0]
return trimmed if trimmed else None
class Diverted(enum.Enum):
"""Whether a shape is diverted or not."""
ANY = 0
LIMITED = 1
DIVERTED = 2
@classmethod
def from_refs(cls, references: named_array.NamedArray) -> "Diverted":
diverted = (references["diverted", 0] == 1)
limited = (references["limited", 0] == 1)
if diverted and limited:
raise ValueError("Diverted and limited doesn't make sense.")
if diverted:
return cls.DIVERTED
if limited:
return cls.LIMITED
return cls.ANY
@dataclasses.dataclass
class Shape:
"""Full specification of a shape."""
params: Optional[ParametrizedShape] = None
points: Optional[ShapePoints] = None
x_points: Optional[ShapePoints] = None
legs: Optional[ShapePoints] = None
diverted: Diverted = Diverted.ANY
ip: Optional[float] = None
limit_point: Optional[Point] = None
@classmethod
def from_references(cls, references: named_array.NamedArray) -> "Shape":
"""Extract a Shape from the references."""
if any(np.any(references[name] != 0)
for name in ("R", "Z", "kappa", "delta", "radius", "lambda")):
params = ParametrizedShape(
r0=references["R"][0],
z0=references["Z"][0],
kappa=references["kappa"][0],
delta=references["delta"][0],
radius=references["radius"][0],
lambda_=references["lambda"][0],
side=ShapeSide.NOSHIFT)
else:
params = None
ip = references["Ip", 0]
return cls(
params,
points=trim_zero_points(points_from_references(references, "shape1")),
x_points=trim_zero_points(points_from_references(references,
"x_points")),
legs=trim_zero_points(points_from_references(references, "legs")),
limit_point=trim_zero_points(points_from_references(
references, "limit_point")[0:1]),
diverted=Diverted.from_refs(references),
ip=float(ip) if ip != 0 else None)
def gen_references(self) -> named_array.NamedArray:
"""Return the references for the parametrized shape."""
refs = tcv_common.REF_RANGES.new_named_array()
if self.ip is not None:
refs["Ip", 0] = self.ip
refs["diverted", 0] = 1 if self.diverted == Diverted.DIVERTED else 0
refs["limited", 0] = 1 if self.diverted == Diverted.LIMITED else 0
if self.params is not None:
refs["R", 0] = self.params.r0
refs["Z", 0] = self.params.z0
refs["kappa", 0] = self.params.kappa
refs["delta", 0] = self.params.delta
refs["radius", 0] = self.params.radius
refs["lambda", 0] = self.params.lambda_
if self.points is not None:
points = np.array(self.points)
assert refs.names.count("shape_r") >= points.shape[0]
refs["shape_r", :points.shape[0]] = points[:, 0]
refs["shape_z", :points.shape[0]] = points[:, 1]
if self.x_points is not None:
x_points = np.array(self.x_points)
assert refs.names.count("x_points_r") >= x_points.shape[0]
refs["x_points_r", :x_points.shape[0]] = x_points[:, 0]
refs["x_points_z", :x_points.shape[0]] = x_points[:, 1]
if self.legs is not None:
legs = np.array(self.legs)
assert refs.names.count("legs_r") >= legs.shape[0]
refs["legs_r", :legs.shape[0]] = legs[:, 0]
refs["legs_z", :legs.shape[0]] = legs[:, 1]
if self.limit_point is not None:
refs["limit_point_r", 0] = self.limit_point.r
refs["limit_point_z", 0] = self.limit_point.z
return refs
def canonical(self) -> "Shape":
"""Return a canonical shape with a fixed number of points and params."""
num_points = tcv_common.REF_RANGES.count("shape_r")
out = copy.deepcopy(self)
if out.points is None:
if out.params is None:
raise ValueError("Can't canonicalize with no params or points.")
out.points, center = out.params.gen_points(num_points)
out.params.r0 = center.r
out.params.z0 = center.z
out.params.side = ShapeSide.NOSHIFT
else:
out.points = spline_interpolate_points(
out.points, num_points, out.x_points or [])
if out.params:
out.params.side = ShapeSide.NOSHIFT
else:
# Copied from FGE. Details: https://doi.org/10.1017/S0022377815001270
top = max(out.points, key=lambda p: p.z)
left = min(out.points, key=lambda p: p.r)
right = max(out.points, key=lambda p: p.r)
bottom = min(out.points, key=lambda p: p.z)
center = Point((left.r + right.r) / 2,
(top.z + bottom.z) / 2)
radius = (right.r - left.r) / 2
kappa = (top.z - bottom.z) / (right.r - left.r)
        delta_lower = (center.r - bottom.r) / radius  # lower triangularity
        delta_upper = (center.r - top.r) / radius  # upper triangularity
delta = (delta_lower + delta_upper) / 2
out.params = ParametrizedShape(
r0=center.r, z0=center.z, radius=radius, kappa=kappa, delta=delta,
lambda_=0, side=ShapeSide.NOSHIFT)
return out
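To make the FGE-style parameter extraction above concrete, here is a worked example on an up-down symmetric shape; the numbers are illustrative:

```python
# Extreme points (r, z) of an ellipse centred at (0.9, 0.2) with
# horizontal half-width 0.25 and vertical half-height 0.35.
left, right = (0.65, 0.2), (1.15, 0.2)
bottom, top = (0.9, -0.15), (0.9, 0.55)

center_r = (left[0] + right[0]) / 2                  # 0.9
radius = (right[0] - left[0]) / 2                    # 0.25
kappa = (top[1] - bottom[1]) / (right[0] - left[0])  # 0.7 / 0.5 = 1.4
delta_lower = (center_r - bottom[0]) / radius        # 0.0
delta_upper = (center_r - top[0]) / radius           # 0.0
delta = (delta_lower + delta_upper) / 2              # 0.0: no triangularity
```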
def points_from_references(
references: named_array.NamedArray, key: str = "shape1",
num: Optional[int] = None) -> ShapePoints:
points = np.array([references[f"{key}_r"], references[f"{key}_z"]]).T
if num is not None:
points = points[:num]
return to_shape_points(points)
@dataclasses.dataclass
class ReferenceTimeSlice:
shape: Shape
time: float
hold: Optional[float] = None # Absolute time.
def __post_init__(self):
if self.hold is None:
self.hold = self.time
def canonicalize_reference_series(
time_slices: List[ReferenceTimeSlice]) -> List[ReferenceTimeSlice]:
"""Canonicalize a full sequence of time slices."""
outputs = []
for ref in time_slices:
ref_shape = ref.shape.canonical()
prev = outputs[-1] if outputs else None
if prev is not None and prev.hold + tcv_common.DT < ref.time:
leg_diff = len(ref_shape.legs or []) != len(prev.shape.legs or [])
xp_diff = len(ref_shape.x_points or []) != len(prev.shape.x_points or [])
div_diff = ref_shape.diverted != prev.shape.diverted
limit_diff = (
bool(ref_shape.limit_point and ref_shape.limit_point.r > 0) !=
bool(prev.shape.limit_point and prev.shape.limit_point.r > 0))
if leg_diff or xp_diff or div_diff or limit_diff:
# Try not to interpolate between a real x-point and a non-existent
# x-point. Non-existent x-points are represented as being at the
# origin, i.e. out to the left of the vessel, and could be interpolated
# into place, but that's weird, so better to pop it into existence by
# adding an extra frame one before with the new shape targets.
# This doesn't handle the case of multiple points appearing/disappearing
# out of order, or of one moving while the other disappears.
outputs.append(
ReferenceTimeSlice(
time=ref.time - tcv_common.DT,
shape=Shape(
ip=ref_shape.ip,
params=ref_shape.params,
points=ref_shape.points,
x_points=(prev.shape.x_points
if xp_diff else ref_shape.x_points),
legs=(prev.shape.legs if leg_diff else ref_shape.legs),
limit_point=(prev.shape.limit_point
if limit_diff else ref_shape.limit_point),
diverted=(prev.shape.diverted
if div_diff else ref_shape.diverted))))
outputs.append(ReferenceTimeSlice(ref_shape, ref.time, ref.hold))
return outputs
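A hedged usage sketch of the time-slice interface (assumes the `fusion_tcv` package is importable; the shape values are illustrative, not from a real shot):

```python
import dataclasses

from fusion_tcv import shape

start = shape.Shape(
    ip=-120000,
    params=shape.ParametrizedShape(
        r0=0.89, z0=0.25, kappa=1.4, delta=0.25, radius=0.25,
        lambda_=0, side=shape.ShapeSide.NOSHIFT),
    diverted=shape.Diverted.LIMITED)
end = dataclasses.replace(start, ip=-110000)  # Same shape, ramped-down Ip.

series = shape.canonicalize_reference_series([
    shape.ReferenceTimeSlice(shape=start, time=0.1),
    shape.ReferenceTimeSlice(shape=end, time=0.45),
])
refs = [(ts.time, ts.shape.gen_references()) for ts in series]
```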


@@ -0,0 +1,68 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A set of known shapes."""
from fusion_tcv import shape
from fusion_tcv import tcv_common
SHAPE_70166_0450 = shape.Shape(
ip=-120000,
params=shape.ParametrizedShape(
r0=0.89,
z0=0.25,
kappa=1.4,
delta=0.25,
radius=0.25,
lambda_=0,
side=shape.ShapeSide.NOSHIFT),
limit_point=shape.Point(tcv_common.INNER_LIMITER_R, 0.25),
diverted=shape.Diverted.LIMITED)
SHAPE_70166_0872 = shape.Shape(
ip=-110000,
params=shape.ParametrizedShape(
r0=0.8796,
z0=0.2339,
kappa=1.2441,
delta=0.2567,
radius=0.2390,
lambda_=0,
side=shape.ShapeSide.NOSHIFT,
),
points=[ # 20 points
shape.Point(0.6299, 0.1413),
shape.Point(0.6481, 0.0577),
shape.Point(0.6804, -0.0087),
shape.Point(0.7286, -0.0513),
shape.Point(0.7931, -0.0660),
shape.Point(0.8709, -0.0513),
shape.Point(0.9543, -0.0087),
shape.Point(1.0304, 0.0577),
shape.Point(1.0844, 0.1413),
shape.Point(1.1040, 0.2340),
shape.Point(1.0844, 0.3267),
shape.Point(1.0304, 0.4103),
shape.Point(0.9543, 0.4767),
shape.Point(0.8709, 0.5193),
shape.Point(0.7931, 0.5340),
shape.Point(0.7286, 0.5193),
shape.Point(0.6804, 0.4767),
shape.Point(0.6481, 0.4103),
shape.Point(0.6299, 0.3267),
shape.Point(0.6240, 0.2340),
],
limit_point=shape.Point(tcv_common.INNER_LIMITER_R, 0.2339),
diverted=shape.Diverted.LIMITED)
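A small usage sketch: the reference generator from `shape.py` turns one of these known shapes into the flat reference array the environment consumes (the module name `shapes_known` is an assumption here, since the file header above was lost):

```python
from fusion_tcv import shapes_known

refs = shapes_known.SHAPE_70166_0872.gen_references()
assert refs["Ip", 0] == -110000
assert refs["limited", 0] == 1  # This shot is limited, not diverted.
```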

622
fusion_tcv/targets.py Normal file

File diff suppressed because it is too large

273
fusion_tcv/tcv_common.py Normal file

@@ -0,0 +1,273 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Constants and general tooling for TCV plant."""
import collections
from typing import Sequence, Text
from dm_env import specs
import numpy as np
from fusion_tcv import named_array
DT = 1e-4 # ie 10kHz
# Below are general input/output specifications used for controllers that are
# run on hardware. This interface corresponds to the so-called "KH hybrid"
# controller specification that is used in various experiments by EPFL. Hence,
# this interface definition contains measurements and actions not used in our
# tasks.
# Number of actions the environment is exposing. Includes dummy (FAST) action.
NUM_ACTIONS = 20
# Number of actuated coils in sim (without the dummy action coil).
NUM_COILS_ACTUATED = 19
# Current and voltage limits by coil type
# Note these are the limits used in the environments and are different
# from the 'machine engineering limits' (as used/exposed by FGE).
# We apply a safety factor (<= 1.0) to the engineering limits.
CURRENT_SAFETY_FACTOR = 0.8
ENV_COIL_MAX_CURRENTS = collections.OrderedDict(
E=7500*CURRENT_SAFETY_FACTOR,
F=7500*CURRENT_SAFETY_FACTOR,
OH=26000*CURRENT_SAFETY_FACTOR,
DUMMY=2000*CURRENT_SAFETY_FACTOR,
G=2000*CURRENT_SAFETY_FACTOR)
# The g-coil has a saturation voltage that is tunable on a shot-by-shot basis.
# There is a deadband, where an action with absolute value of less than 8% of
# the saturation voltage is treated as zero.
ENV_G_COIL_SATURATION_VOLTAGE = 300
ENV_G_COIL_DEADBAND = ENV_G_COIL_SATURATION_VOLTAGE * 0.08
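The deadband comment above translates to a simple clipping rule. A hedged illustration (a hypothetical helper; the real handling lives on the plant/simulator side, not in this file):

```python
def g_coil_effective_voltage(commanded: float) -> float:
  """Hypothetical sketch: applies the 8% deadband, then saturates."""
  if abs(commanded) < ENV_G_COIL_DEADBAND:
    return 0.0  # Within +/-8% of the saturation voltage: treated as zero.
  return max(-ENV_G_COIL_SATURATION_VOLTAGE,
             min(ENV_G_COIL_SATURATION_VOLTAGE, commanded))
```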
VOLTAGE_SAFETY_FACTOR = 1.0
ENV_COIL_MAX_VOLTAGE = collections.OrderedDict(
E=1400*VOLTAGE_SAFETY_FACTOR,
F=2200*VOLTAGE_SAFETY_FACTOR,
OH=1400*VOLTAGE_SAFETY_FACTOR,
DUMMY=400*VOLTAGE_SAFETY_FACTOR,
# This value is also used to clip values for the internal controller,
# and also to set the deadband voltage.
G=ENV_G_COIL_SATURATION_VOLTAGE)
# Ordered actions sent by a controller to the TCV.
TCV_ACTIONS = (
'E_001', 'E_002', 'E_003', 'E_004', 'E_005', 'E_006', 'E_007', 'E_008',
'F_001', 'F_002', 'F_003', 'F_004', 'F_005', 'F_006', 'F_007', 'F_008',
'OH_001', 'OH_002',
'DUMMY_001', # GAS, ignored by TCV.
'G_001' # FAST
)
TCV_ACTION_INDICES = {n: i for i, n in enumerate(TCV_ACTIONS)}
TCV_ACTION_TYPES = collections.OrderedDict(
E=8,
F=8,
OH=2,
DUMMY=1,
G=1,
)
# Map the TCV actions to ranges of indices in the array.
TCV_ACTION_RANGES = named_array.NamedRanges(TCV_ACTION_TYPES)
# The voltages seem not to be centered at 0, but instead near these values:
TCV_ACTION_OFFSETS = {
'E_001': 6.79,
'E_002': -10.40,
'E_003': -1.45,
'E_004': 0.18,
'E_005': 11.36,
'E_006': -0.95,
'E_007': -4.28,
'E_008': 44.22,
'F_001': 38.49,
'F_002': -2.94,
'F_003': 5.58,
'F_004': 1.09,
'F_005': -36.63,
'F_006': -9.18,
'F_007': 5.34,
'F_008': 10.53,
'OH_001': -53.63,
'OH_002': -14.76,
}
TCV_ACTION_DELAYS = {
'E': [0.0005] * 8,
'F': [0.0005] * 8,
'OH': [0.0005] * 2,
'G': [0.0001],
}
# Ordered measurements and their dimensions, according to the TCV controller
# specs.
TCV_MEASUREMENTS = collections.OrderedDict(
clint_vloop=1, # Flux loop 1
clint_rvloop=37, # Difference of flux between loops 2-38 and flux loop 1
bm=38, # Magnetic field probes
IE=8, # E-coil currents
IF=8, # F-coil currents
IOH=2, # OH-coil currents
Bdot=20, # Selection of 20 time-derivatives of magnetic field probes (bm).
DIOH=1, # OH-coil currents difference: OH(0) - OH(1).
FIR_FRINGE=1, # Not used, ignore.
IG=1, # G-coil current
ONEMM=1, # Not used, ignore
vloop=1, # Flux loop 1 derivative
IPHI=1, # Current through the Toroidal Field coils. Constant. Ignore.
)
NUM_MEASUREMENTS = sum(TCV_MEASUREMENTS.values())
# Map the TCV measurements to ranges of indices in the array.
TCV_MEASUREMENT_RANGES = named_array.NamedRanges(TCV_MEASUREMENTS)
# Several of the measurement probes for the rvloops are broken. Add an extra
# key that allows us to grab only the usable ones.
BROKEN_RVLOOP_IDXS = [9, 10, 11]
TCV_MEASUREMENT_RANGES.set_range('clint_rvloop_usable', [
idx for i, idx in enumerate(TCV_MEASUREMENT_RANGES['clint_rvloop'])
if i not in BROKEN_RVLOOP_IDXS])
TCV_COIL_CURRENTS_INDEX = [
*TCV_MEASUREMENT_RANGES['IE'],
*TCV_MEASUREMENT_RANGES['IF'],
*TCV_MEASUREMENT_RANGES['IOH'],
*TCV_MEASUREMENT_RANGES['IPHI'], # In place of DUMMY.
*TCV_MEASUREMENT_RANGES['IG'],
]
# References for what we want the agent to accomplish.
REF_RANGES = named_array.NamedRanges({
'R': 2,
'Z': 2,
'Ip': 2,
'kappa': 2,
'delta': 2,
'radius': 2,
'lambda': 2,
'diverted': 2, # bool, must be diverted
'limited': 2, # bool, must be limited
'shape_r': 32,
'shape_z': 32,
'x_points_r': 8,
'x_points_z': 8,
'legs_r': 16, # Use for diverted/snowflake
'legs_z': 16,
'limit_point_r': 2,
'limit_point_z': 2,
})
# Environments should use a consistent datatype for interacting with agents.
ENVIRONMENT_DATA_TYPE = np.float64
def observation_spec():
"""Observation spec for all TCV environments."""
return {
'references':
specs.Array(
shape=(REF_RANGES.size,),
dtype=ENVIRONMENT_DATA_TYPE,
name='references'),
'measurements':
specs.Array(
shape=(TCV_MEASUREMENT_RANGES.size,),
dtype=ENVIRONMENT_DATA_TYPE,
name='measurements'),
'last_action':
specs.Array(
shape=(TCV_ACTION_RANGES.size,),
dtype=ENVIRONMENT_DATA_TYPE,
name='last_action'),
}
def measurements_to_dict(measurements):
"""Converts a single measurement vector or a time series to a dict.
Args:
measurements: A single measurement of size `NUM_MEASUREMENTS` or a time
series, where the batch dimension is last, shape: (NUM_MEASUREMENTS, t).
Returns:
A dict mapping keys `TCV_MEASUREMENTS` to the corresponding measurements.
"""
assert measurements.shape[0] == NUM_MEASUREMENTS
measurements_dict = collections.OrderedDict()
index = 0
for key, dim in TCV_MEASUREMENTS.items():
measurements_dict[key] = measurements[index:index + dim, ...]
index += dim
return measurements_dict
def dict_to_measurement(measurement_dict):
"""Converts a single measurement dict to a vector or time series.
Args:
measurement_dict: A dict with the measurement keys containing np arrays of
size (meas_size, ...). The inner sizes all have to be the same.
Returns:
An array of size (num_measurements, ...)
"""
assert len(measurement_dict) == len(TCV_MEASUREMENTS)
# Grab the shape of the first array.
shape = measurement_dict['clint_vloop'].shape
out_shape = list(shape)
out_shape[0] = NUM_MEASUREMENTS
out_shape = tuple(out_shape)
  measurements = np.zeros(out_shape)
  index = 0
  for key, dim in TCV_MEASUREMENTS.items():
    measurements[index:index + dim, ...] = measurement_dict[key]
index += dim
return measurements
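A quick round-trip check of the two converters above (shapes only; values are zeros):

```python
import numpy as np

m = np.zeros((NUM_MEASUREMENTS, 5))      # a 5-step measurement time series
d = measurements_to_dict(m)
assert d["bm"].shape == (38, 5)          # 38 magnetic probe signals per step
assert dict_to_measurement(d).shape == m.shape
```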
def action_spec():
return get_coil_spec(TCV_ACTIONS, ENV_COIL_MAX_VOLTAGE, ENVIRONMENT_DATA_TYPE)
def get_coil_spec(coil_names: Sequence[Text],
spec_mapping,
dtype=ENVIRONMENT_DATA_TYPE) -> specs.BoundedArray:
"""Maps specs indexed by coil type to coils given their type."""
coil_max, coil_min = [], []
for name in coil_names:
# Coils names are <coil_type>_<coil_number>
coil_type, _ = name.split('_')
coil_max.append(spec_mapping[coil_type])
coil_min.append(-spec_mapping[coil_type])
return specs.BoundedArray(
shape=(len(coil_names),), dtype=dtype, minimum=coil_min, maximum=coil_max)
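For example, the resulting action spec bounds each coil by its type; a quick sanity sketch using the module-level names above:

```python
spec = action_spec()
assert spec.shape == (NUM_ACTIONS,)   # 20 actions
assert spec.maximum[0] == 1400.0      # 'E_001' gets the E-coil voltage limit
assert spec.maximum[-1] == 300.0      # 'G_001' saturates at +/-300 V
```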
INNER_LIMITER_R = 0.62400001
OUTER_LIMITER_R = 1.14179182
LIMITER_WIDTH = OUTER_LIMITER_R - INNER_LIMITER_R
LIMITER_RADIUS = LIMITER_WIDTH / 2
VESSEL_CENTER_R = INNER_LIMITER_R + LIMITER_RADIUS

102
fusion_tcv/terminations.py Normal file

@@ -0,0 +1,102 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Terminations for the fusion environment."""
import abc
from typing import List, Optional
import numpy as np
from fusion_tcv import fge_state
from fusion_tcv import tcv_common
class Abstract(abc.ABC):
"""Abstract reward class."""
@abc.abstractmethod
def terminate(self, state: fge_state.FGEState) -> Optional[str]:
"""Returns a reason if the situation should be considered a termination."""
class CoilCurrentSaturation(Abstract):
"""Terminates if the coils have saturated their current."""
def terminate(self, state: fge_state.FGEState) -> Optional[str]:
# Coil currents are checked by type, independent of the order.
for coil_type, max_current in tcv_common.ENV_COIL_MAX_CURRENTS.items():
if coil_type == "DUMMY":
continue
currents = state.get_coil_currents_by_type(coil_type)
if (np.abs(currents) > max_current).any():
return (f"CoilCurrentSaturation: {coil_type}, max: {max_current}, "
"real: " + ", ".join(f"{c:.1f}" for c in currents))
return None
class OHTooDifferent(Abstract):
"""Terminates if the coil currents are too far apart from one another."""
def __init__(self, max_diff: float):
self._max_diff = max_diff
def terminate(self, state: fge_state.FGEState) -> Optional[str]:
oh_coil_currents = state.get_coil_currents_by_type("OH")
assert len(oh_coil_currents) == 2
oh_current_abs = abs(oh_coil_currents[0] - oh_coil_currents[1])
if oh_current_abs > self._max_diff:
return ("OHTooDifferent: currents: "
f"({oh_coil_currents[0]:.0f}, {oh_coil_currents[1]:.0f}), "
f"diff: {oh_current_abs:.0f}, max: {self._max_diff}")
return None
class IPTooLow(Abstract):
"""Terminates if the magnitude of Ip in any component is too low."""
def __init__(self, singlet_threshold: float, droplet_threshold: float):
self._singlet_threshold = singlet_threshold
self._droplet_threshold = droplet_threshold
def terminate(self, state: fge_state.FGEState) -> Optional[str]:
_, _, ip_d = state.rzip_d
if len(ip_d) == 1:
if ip_d[0] > self._singlet_threshold: # Sign due to negative Ip.
return f"IPTooLow: Singlet, {ip_d[0]:.0f}"
return None
else:
if max(ip_d) > self._droplet_threshold: # Sign due to negative Ip.
return f"IPTooLow: Components: {ip_d[0]:.0f}, {ip_d[1]:.0f}"
return None
class AnyTermination(Abstract):
"""Terminates if any of conditions are met."""
def __init__(self, terminators: List[Abstract]):
self._terminators = terminators
def terminate(self, state: fge_state.FGEState) -> Optional[str]:
for terminator in self._terminators:
term = terminator.terminate(state)
if term:
return term
return None
CURRENT_OH_IP = AnyTermination([
CoilCurrentSaturation(),
OHTooDifferent(max_diff=4000),
IPTooLow(singlet_threshold=-60000, droplet_threshold=-25000),
])
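A hedged sketch of how these terminators compose, using a minimal stand-in that implements only the two members of `FGEState` touched above (`get_coil_currents_by_type` and `rzip_d`); the real state object comes from the FGE simulator:

```python
import numpy as np

class _FakeState:
  """Stand-in exposing only what the terminators above consume."""

  def get_coil_currents_by_type(self, coil_type):
    return np.array([100.0, -200.0]) if coil_type == "OH" else np.zeros(8)

  @property
  def rzip_d(self):
    return None, None, [-80000.0]  # (r, z, per-component Ip)

assert CURRENT_OH_IP.terminate(_FakeState()) is None  # Healthy: no termination.
```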

40
fusion_tcv/trajectory.py Normal file

@@ -0,0 +1,40 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A trajectory for an episode."""
from typing import List
import dataclasses
import numpy as np
@dataclasses.dataclass
class Trajectory:
"""A trajectory of actions/obs for an episode."""
measurements: np.ndarray
references: np.ndarray
reward: np.ndarray
actions: np.ndarray
@classmethod
def stack(cls, series: List["Trajectory"]) -> "Trajectory":
"""Stack a series of trajectories, adding a trailing time dimension."""
values = {k: np.empty(v.shape + (len(series),))
for k, v in dataclasses.asdict(series[0]).items()
if v is not None}
for i, ts in enumerate(series):
for k, v in values.items():
v[..., i] = getattr(ts, k)
out = cls(**values)
return out
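A short usage sketch of `stack` (illustrative sizes):

```python
import numpy as np

t0 = Trajectory(measurements=np.zeros(3), references=np.zeros(2),
                reward=np.zeros(1), actions=np.zeros(4))
t1 = Trajectory(measurements=np.ones(3), references=np.ones(2),
                reward=np.ones(1), actions=np.ones(4))
stacked = Trajectory.stack([t0, t1])
assert stacked.measurements.shape == (3, 2)  # Time is the trailing axis.
```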

183
fusion_tcv/transforms.py Normal file

@@ -0,0 +1,183 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transforms from actual/target to rewards."""
import abc
import math
from typing import List, Optional
import dataclasses
# Comparison of some of the transforms:
# Order is NegExp, SoftPlus, Sigmoid.
# Over range of good to bad:
# https://www.wolframalpha.com/input/?i=plot+e%5E%28x*ln%280.1%29%29%2C2%2F%281%2Be%5E%28x*ln%2819%29%29%29%2C1%2F%281%2Be%5E%28-+%28-ln%2819%29+-+%28x-1%29*%282*ln%2819%29%29%29%29%29+from+x%3D0+to+2
# When close to good:
# https://www.wolframalpha.com/input/?i=plot+e%5E%28x*ln%280.1%29%29%2C2%2F%281%2Be%5E%28x*ln%2819%29%29%29%2C1%2F%281%2Be%5E%28-+%28-ln%2819%29+-+%28x-1%29*%282*ln%2819%29%29%29%29%29+from+x%3D0+to+0.2
class AbstractTransform(abc.ABC):
@abc.abstractmethod
def __call__(self, errors: List[float]) -> List[float]:
"""Transforms target errors into rewards."""
@property
def outputs(self) -> Optional[int]:
return None
def clip(value: float, low: float, high: float) -> float:
"""Clip a value to the range of low - high."""
if math.isnan(value):
return value
assert low <= high
return max(low, min(high, value))
def scale(v: float, a: float, b: float, c: float, d: float) -> float:
"""Scale a value, v on a line with anchor points a,b to new anchors c,d."""
v01 = (v - a) / (b - a)
return c - v01 * (c - d)
def logistic(v: float) -> float:
"""Standard logistic, asymptoting to 0 and 1."""
v = clip(v, -50, 50) # Improve numerical stability.
return 1 / (1 + math.exp(-v))
@dataclasses.dataclass(frozen=True)
class Equal(AbstractTransform):
"""Returns 1 if the error is 0 and not_equal_val otherwise."""
not_equal_val: float = 0
def __call__(self, errors: List[float]) -> List[float]:
out = []
for err in errors:
if math.isnan(err):
out.append(err)
elif err == 0:
out.append(1)
else:
out.append(self.not_equal_val)
return out
class Abs(AbstractTransform):
"""Take the absolue value of the error. Does not guarantee 0-1."""
@staticmethod
def __call__(errors: List[float]) -> List[float]:
return [abs(err) for err in errors]
class Neg(AbstractTransform):
"""Negate the error. Does not guarantee 0-1."""
@staticmethod
def __call__(errors: List[float]) -> List[float]:
return [-err for err in errors]
@dataclasses.dataclass(frozen=True)
class Pow(AbstractTransform):
"""Return a power of the error. Does not guarantee 0-1."""
pow: float
def __call__(self, errors: List[float]) -> List[float]:
return [err**self.pow for err in errors]
@dataclasses.dataclass(frozen=True)
class Log(AbstractTransform):
"""Return a log of the error. Does not guarantee 0-1."""
eps: float = 1e-4
def __call__(self, errors: List[float]) -> List[float]:
return [math.log(err + self.eps) for err in errors]
@dataclasses.dataclass(frozen=True)
class ClippedLinear(AbstractTransform):
"""Scales and clips errors, bad to 0, good to 1. If good=0, this is a relu."""
bad: float
good: float = 0
def __call__(self, errors: List[float]) -> List[float]:
return [clip(scale(err, self.bad, self.good, 0, 1), 0, 1)
for err in errors]
@dataclasses.dataclass(frozen=True)
class SoftPlus(AbstractTransform):
"""Scales and clips errors, bad to 0.1, good to 1, asymptoting to 0.
Based on the lower half of the logistic instead of the standard softplus as
we want it to be bounded from 0 to 1, with the good value being exactly 1.
Various constants can be chosen to get the softplus to give the desired
properties, but this is much simpler.
"""
bad: float
good: float = 0
# Constant to set the sharpness/slope of the softplus.
# Default was chosen such that the good/bad have 1 and 0.1 reward:
# https://www.wolframalpha.com/input/?i=plot+2%2F%281%2Be%5E%28x*ln%2819%29%29%29+from+x%3D0+to+2
low: float = -math.log(19) # -2.9444389791664403
def __call__(self, errors: List[float]) -> List[float]:
return [clip(2 * logistic(scale(e, self.bad, self.good, self.low, 0)), 0, 1)
for e in errors]
@dataclasses.dataclass(frozen=True)
class NegExp(AbstractTransform):
"""Scales and clips errors, bad to 0.1, good to 1, asymptoting to 0.
This scales the reward in an exponential space. This means there is a sharp
gradient toward reaching the value of good, flattening out at the value of
bad. This can be useful for a reward that gives meaningful signal far away,
  but still has a sharp gradient near the true target.
"""
bad: float
good: float = 0
# Constant to set the sharpness/slope of the exponential.
# Default was chosen such that the good/bad have 1 and 0.1 reward:
# https://www.wolframalpha.com/input/?i=plot+e%5E%28x*ln%280.1%29%29+from+x%3D0+to+2
low: float = -math.log(0.1)
def __call__(self, errors: List[float]) -> List[float]:
return [clip(math.exp(-scale(e, self.bad, self.good, self.low, 0)), 0, 1)
for e in errors]
@dataclasses.dataclass(frozen=True)
class Sigmoid(AbstractTransform):
"""Scales and clips errors, bad to 0.05, good to 0.95, asymptoting to 0-1."""
good: float
bad: float
# Constants to set the sharpness/slope of the sigmoid.
# Defaults were chosen such that the good/bad have 0.95 and 0.05 reward:
# https://www.wolframalpha.com/input/?i=plot+1%2F%281%2Be%5E%28-+%28-ln%2819%29+-+%28x-1%29*%282*ln%2819%29%29%29%29%29+from+x%3D0+to+2
high: float = math.log(19) # +2.9444389791664403
low: float = -math.log(19) # -2.9444389791664403
def __call__(self, errors: List[float]) -> List[float]:
return [logistic(scale(err, self.bad, self.good, self.low, self.high))
for err in errors]
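End-to-end, these transforms map raw target errors to rewards in [0, 1]. A short sketch using the classes above (error values illustrative, e.g. distances in metres):

```python
sigmoid = Sigmoid(good=0.01, bad=0.05)
print(sigmoid([0.01, 0.03, 0.05]))  # ~[0.95, 0.5, 0.05]

negexp = NegExp(good=0.01, bad=0.05)
print(negexp([0.01, 0.05]))         # ~[1.0, 0.1]
```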


@@ -0,0 +1,162 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for transforms."""
import math
from absl.testing import absltest
from fusion_tcv import transforms
NAN = float("nan")
class TransformsTest(absltest.TestCase):
def assertNan(self, value: float):
self.assertTrue(math.isnan(value))
def test_clip(self):
self.assertEqual(transforms.clip(-1, 0, 1), 0)
self.assertEqual(transforms.clip(5, 0, 1), 1)
self.assertEqual(transforms.clip(0.5, 0, 1), 0.5)
self.assertNan(transforms.clip(NAN, 0, 1))
def test_scale(self):
self.assertEqual(transforms.scale(0, 0, 0.5, 0, 1), 0)
self.assertEqual(transforms.scale(0.125, 0, 0.5, 0, 1), 0.25)
self.assertEqual(transforms.scale(0.25, 0, 0.5, 0, 1), 0.5)
self.assertEqual(transforms.scale(0.5, 0, 0.5, 0, 1), 1)
self.assertEqual(transforms.scale(1, 0, 0.5, 0, 1), 2)
self.assertEqual(transforms.scale(-1, 0, 0.5, 0, 1), -2)
self.assertEqual(transforms.scale(0.5, 1, 0, 0, 1), 0.5)
self.assertEqual(transforms.scale(0.25, 1, 0, 0, 1), 0.75)
self.assertEqual(transforms.scale(0, 0, 1, -4, 4), -4)
self.assertEqual(transforms.scale(0.25, 0, 1, -4, 4), -2)
self.assertEqual(transforms.scale(0.5, 0, 1, -4, 4), 0)
self.assertEqual(transforms.scale(0.75, 0, 1, -4, 4), 2)
self.assertEqual(transforms.scale(1, 0, 1, -4, 4), 4)
self.assertNan(transforms.scale(NAN, 0, 1, -4, 4))
def test_logistic(self):
self.assertLess(transforms.logistic(-50), 0.000001)
self.assertLess(transforms.logistic(-5), 0.01)
self.assertEqual(transforms.logistic(0), 0.5)
self.assertGreater(transforms.logistic(5), 0.99)
self.assertGreater(transforms.logistic(50), 0.999999)
self.assertAlmostEqual(transforms.logistic(0.8), math.tanh(0.4) / 2 + 0.5)
self.assertNan(transforms.logistic(NAN))
def test_exp_scaled(self):
t = transforms.NegExp(good=0, bad=1)
self.assertNan(t([NAN])[0])
self.assertAlmostEqual(t([0])[0], 1)
self.assertAlmostEqual(t([1])[0], 0.1)
self.assertLess(t([50])[0], 0.000001)
t = transforms.NegExp(good=10, bad=30)
self.assertAlmostEqual(t([0])[0], 1)
self.assertAlmostEqual(t([10])[0], 1)
self.assertLess(t([3000])[0], 0.000001)
t = transforms.NegExp(good=30, bad=10)
self.assertAlmostEqual(t([50])[0], 1)
self.assertAlmostEqual(t([30])[0], 1)
self.assertAlmostEqual(t([10])[0], 0.1)
self.assertLess(t([-90])[0], 0.00001)
def test_neg(self):
t = transforms.Neg()
self.assertEqual(t([-5, -3, 0, 1, 4]), [5, 3, 0, -1, -4])
self.assertNan(t([NAN])[0])
def test_abs(self):
t = transforms.Abs()
self.assertEqual(t([-5, -3, 0, 1, 4]), [5, 3, 0, 1, 4])
self.assertNan(t([NAN])[0])
def test_pow(self):
t = transforms.Pow(2)
self.assertEqual(t([-5, -3, 0, 1, 4]), [25, 9, 0, 1, 16])
self.assertNan(t([NAN])[0])
def test_log(self):
t = transforms.Log()
self.assertAlmostEqual(t([math.exp(2)])[0], 2, 4) # Low precision from eps.
self.assertNan(t([NAN])[0])
def test_clipped_linear(self):
t = transforms.ClippedLinear(good=0.1, bad=0.3)
self.assertAlmostEqual(t([0])[0], 1)
self.assertAlmostEqual(t([0.05])[0], 1)
self.assertAlmostEqual(t([0.1])[0], 1)
self.assertAlmostEqual(t([0.15])[0], 0.75)
self.assertAlmostEqual(t([0.2])[0], 0.5)
self.assertAlmostEqual(t([0.25])[0], 0.25)
self.assertAlmostEqual(t([0.3])[0], 0)
self.assertAlmostEqual(t([0.4])[0], 0)
self.assertNan(t([NAN])[0])
t = transforms.ClippedLinear(good=1, bad=0.5)
self.assertAlmostEqual(t([1.5])[0], 1)
self.assertAlmostEqual(t([1])[0], 1)
self.assertAlmostEqual(t([0.75])[0], 0.5)
self.assertAlmostEqual(t([0.5])[0], 0)
self.assertAlmostEqual(t([0.25])[0], 0)
def test_softplus(self):
t = transforms.SoftPlus(good=0.1, bad=0.3)
self.assertEqual(t([0])[0], 1)
self.assertEqual(t([0.1])[0], 1)
self.assertAlmostEqual(t([0.3])[0], 0.1)
self.assertLess(t([0.5])[0], 0.01)
self.assertNan(t([NAN])[0])
t = transforms.SoftPlus(good=1, bad=0.5)
self.assertEqual(t([1.5])[0], 1)
self.assertEqual(t([1])[0], 1)
self.assertAlmostEqual(t([0.5])[0], 0.1)
self.assertLess(t([0.1])[0], 0.01)
def test_sigmoid(self):
t = transforms.Sigmoid(good=0.1, bad=0.3)
self.assertGreater(t([0])[0], 0.99)
self.assertAlmostEqual(t([0.1])[0], 0.95)
self.assertAlmostEqual(t([0.2])[0], 0.5)
self.assertAlmostEqual(t([0.3])[0], 0.05)
self.assertLess(t([0.4])[0], 0.01)
self.assertNan(t([NAN])[0])
t = transforms.Sigmoid(good=1, bad=0.5)
self.assertGreater(t([1.5])[0], 0.99)
self.assertAlmostEqual(t([1])[0], 0.95)
self.assertAlmostEqual(t([0.75])[0], 0.5)
self.assertAlmostEqual(t([0.5])[0], 0.05)
self.assertLess(t([0.25])[0], 0.01)
def test_equal(self):
t = transforms.Equal()
self.assertEqual(t([0])[0], 1)
self.assertEqual(t([0.001])[0], 0)
self.assertNan(t([NAN])[0])
t = transforms.Equal(not_equal_val=0.5)
self.assertEqual(t([0])[0], 1)
self.assertEqual(t([0.001])[0], 0.5)
if __name__ == "__main__":
absltest.main()