mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-29 19:55:25 +08:00
Release of IODINE
PiperOrigin-RevId: 299101887
This commit is contained in:
@@ -10,6 +10,7 @@ env:
|
|||||||
matrix:
|
matrix:
|
||||||
- PROJECT="tvt"
|
- PROJECT="tvt"
|
||||||
- PROJECT="cs_gan"
|
- PROJECT="cs_gan"
|
||||||
|
- PROJECT="iodine"
|
||||||
- PROJECT="transporter"
|
- PROJECT="transporter"
|
||||||
before_script:
|
before_script:
|
||||||
- sudo apt-get update -qq
|
- sudo apt-get update -qq
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
|
|||||||
|
|
||||||
## Projects
|
## Projects
|
||||||
|
|
||||||
|
* [Multi-Object Representation Learning with Iterative Variational Inference (IODINE)](iodine)
|
||||||
* [AlphaFold CASP13](alphafold_casp13), Nature 2020
|
* [AlphaFold CASP13](alphafold_casp13), Nature 2020
|
||||||
* [Unrestricted Adversarial Challenge](unrestricted_advx)
|
* [Unrestricted Adversarial Challenge](unrestricted_advx)
|
||||||
* [Hierarchical Probabilistic U-Net (HPU-Net)](hierarchical_probabilistic_unet)
|
* [Hierarchical Probabilistic U-Net (HPU-Net)](hierarchical_probabilistic_unet)
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,142 @@
|
|||||||
|
# IODINE
|
||||||
|
Reference implementation for the paper ["Multi-Object Representation Learning with Iterative Variational Inference"](https://arxiv.org/abs/1903.00450).
|
||||||
|
This repository contains:
|
||||||
|
|
||||||
|
* An IODINE implementation in Tensorflow v1.
|
||||||
|
* Configurations used in the paper (checkpoints available in Cloud Storage) for:
|
||||||
|
* CLEVR
|
||||||
|
* Multi-dSprites
|
||||||
|
* Tetrominoes
|
||||||
|
* A notebook for running and inspecting the model and plotting the results
|
||||||
|
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
1. Clone the DeepMind research repository:
|
||||||
|
|
||||||
|
``` bash
|
||||||
|
git clone https://github.com/deepmind/deepmind-research.git
|
||||||
|
cd deepmind-research
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Download the checkpoints from GCP. A shell script is provided:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./iodine/download_checkpoints.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
On platforms without wget, the files can be downloaded from [this webpage](https://console.cloud.google.com/storage/browser/deepmind-research-iodine?pli=1)
|
||||||
|
and the unzipped `checkpoints/` folder should be placed in
|
||||||
|
`deepmind-research/iodine/checkpoints`.
|
||||||
|
|
||||||
|
|
||||||
|
3. Prepare a Python 3 environment - virtualenv is recommended.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m venv iodine_venv
|
||||||
|
source iodine_venv/bin/activate
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Install dependencies:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip3 install -r iodine/requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
5. The `multi_object_datasets` package installed via requirements.txt provides python code to open the data files, but not the data files themselves.
|
||||||
|
Download the desired datasets either manually from the [Google Cloud Storage](https://console.cloud.google.com/storage/browser/multi-object-datasets) or using the commands below:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pushd iodine/multi_object_datasets
|
||||||
|
# CLEVR
|
||||||
|
wget https://storage.googleapis.com/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords
|
||||||
|
# Multi-dSprites
|
||||||
|
wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_grayscale.tfrecords
|
||||||
|
# Tetrominoes
|
||||||
|
wget https://storage.googleapis.com/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords
|
||||||
|
# Get back to location containing 'iodine' directory
|
||||||
|
popd
|
||||||
|
```
|
||||||
|
|
||||||
|
See [multi_object_datasets repository](https://github.com/deepmind/multi_object_datasets)
|
||||||
|
for further details.
|
||||||
|
6. Make sure that you have CUDA 10 and CuDNN 7 installed
|
||||||
|
|
||||||
|
|
||||||
|
## Interact with a Model
|
||||||
|
Use the jupyter notebook `Eval.ipynb` to load and run one of the checkpoints.
|
||||||
|
It also contains code to plot the outputs and latent traversals.
|
||||||
|
|
||||||
|
|
||||||
|
## Train a Model
|
||||||
|
To train your own model use the [Sacred](https://github.com/IDSIA/sacred) experiment defined in `main.py`.
|
||||||
|
The configurations used in the paper for the different datasets are available as [named configs](https://sacred.readthedocs.io/en/latest/configuration.html#named-configurations) inside of `configuration.py`.
|
||||||
|
### Train a new model
|
||||||
|
* CLEVR6
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m iodine.main -f with clevr6
|
||||||
|
```
|
||||||
|
|
||||||
|
* Multi-dSprites
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m iodine.main -f with multi_dsprites
|
||||||
|
```
|
||||||
|
|
||||||
|
* Tetrominoes
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m iodine.main -f with tetrominoes
|
||||||
|
```
|
||||||
|
|
||||||
|
It is recommended to add an observer to your run to let Sacred record the details of run.
|
||||||
|
To add a [FileStorageObserver](https://sacred.readthedocs.io/en/latest/command_line.html#filestorage-observer) add `-F my_storage_dir`, and add `-m my_db_name` for a [MongoObserver](https://sacred.readthedocs.io/en/latest/command_line.html#mongodb-observer).
|
||||||
|
|
||||||
|
### Adjusting Config Values
|
||||||
|
The experiment has a configuration that can be printed and adjusted from the commandline. E.g.:
|
||||||
|
|
||||||
|
``` bash
|
||||||
|
# print configuration
|
||||||
|
python3 -m iodine.main -f print_config with clevr6
|
||||||
|
# run experiment after adjusting batch_size and the size of the shuffle buffer
|
||||||
|
python3 -m iodine.main -f with clevr6 batch_size=2 data.shuffle_buffer=100
|
||||||
|
```
|
||||||
|
|
||||||
|
### Tensorboard
|
||||||
|
Each run stores checkpoints and summaries in the directory specified by `checkpoint_dir`, to which a suffix based on the run_id is appended.
|
||||||
|
If an observer is added the `run_id` is set automatically. Otherwise it should be set manually using e.g. `run_id=5`.
|
||||||
|
|
||||||
|
Summaries can be viewed using tensorboard. E.g. like this for clevr6 (assuming `run_id=1`):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
tensorboard --log-dir iodine/checkpoints/clevr6_1
|
||||||
|
```
|
||||||
|
|
||||||
|
### Continue Previous Run
|
||||||
|
To continue a previous run pass `continue_run=True` and the path of the checkpoints:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 -m iodine.main -f with clevr6 checkpoint_dir=iodine/checkpoints/clevr6_1
|
||||||
|
```
|
||||||
|
|
||||||
|
## Code Structure
|
||||||
|
The main experiment defined in `main.py` uses `sacred` and the configurations for the different datasets are added as named configs and can be found in `configuration.py`.
|
||||||
|
The model implementation can be found in the `modules` directory and is based on `tensorflow` and `sonnet`:
|
||||||
|
|
||||||
|
* `iodine.py` The main IODINE module that assembles the decoder, refinement network, distributions and factor regressor.
|
||||||
|
* `decoder.py` The ComponentDecoder which is a wrapper around networks that takes care of splitting the output channels into means and masks.
|
||||||
|
* `refinement.py` The refinement components assembles the encoder network, LSTM and refinement head.
|
||||||
|
* `networks.py` Different standard networks such as CNN, BroadcastCNN, and LSTM.
|
||||||
|
* `distribution.py` Definition of the latent and pixel distributions.
|
||||||
|
* `factor_eval.py` Contains the factor regressor which predicts the true factors from the inferred object latents.
|
||||||
|
* `data.py` Dataset wrappers around `multi_object_datasets` that take care of shuffling, batching and preprocessing.
|
||||||
|
* `plotting.py` Helper functions for plotting results.
|
||||||
|
* `utils.py` General helper functions.
|
||||||
|
|
||||||
|
|
||||||
|
---
|
||||||
|
**DISCLAIMER**
|
||||||
|
|
||||||
|
This is not an officially supported Google product.
|
||||||
|
|
||||||
|
---
|
||||||
@@ -0,0 +1,370 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Configurations for IODINE."""
|
||||||
|
# pylint: disable=missing-docstring, unused-variable
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
def clevr6():
|
||||||
|
n_z = 64 # number of latent dimensions
|
||||||
|
num_components = 7 # number of components (K)
|
||||||
|
num_iters = 5
|
||||||
|
checkpoint_dir = "iodine/checkpoints/clevr6"
|
||||||
|
|
||||||
|
# For the paper we used 8 GPUs with a batch size of 4 each.
|
||||||
|
# This means a total batch size of 32, which is too large for a single GPU.
|
||||||
|
# When reducing the batch size, the learning rate should also be lowered.
|
||||||
|
batch_size = 4
|
||||||
|
learn_rate = 0.001 * math.sqrt(batch_size / 32)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"constructor": "iodine.modules.data.CLEVR",
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"path": "multi_object_datasets/clevr_with_masks_train.tfrecords",
|
||||||
|
"max_num_objects": 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
model = {
|
||||||
|
"constructor": "iodine.modules.iodine.IODINE",
|
||||||
|
"n_z": n_z,
|
||||||
|
"num_components": num_components,
|
||||||
|
"num_iters": num_iters,
|
||||||
|
"iter_loss_weight": "linspace",
|
||||||
|
"coord_type": "linear",
|
||||||
|
"decoder": {
|
||||||
|
"constructor": "iodine.modules.decoder.ComponentDecoder",
|
||||||
|
"pixel_decoder": {
|
||||||
|
"constructor": "iodine.modules.networks.BroadcastConv",
|
||||||
|
"cnn_opt": {
|
||||||
|
# Final channels is irrelevant with target_output_shape
|
||||||
|
"output_channels": [64, 64, 64, 64, None],
|
||||||
|
"kernel_shapes": [3],
|
||||||
|
"strides": [1],
|
||||||
|
"activation": "elu",
|
||||||
|
},
|
||||||
|
"coord_type": "linear",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"refinement_core": {
|
||||||
|
"constructor": "iodine.modules.refinement.RefinementCore",
|
||||||
|
"encoder_net": {
|
||||||
|
"constructor": "iodine.modules.networks.CNN",
|
||||||
|
"mode": "avg_pool",
|
||||||
|
"cnn_opt": {
|
||||||
|
"output_channels": [64, 64, 64, 64],
|
||||||
|
"strides": [2],
|
||||||
|
"kernel_shapes": [3],
|
||||||
|
"activation": "elu",
|
||||||
|
},
|
||||||
|
"mlp_opt": {
|
||||||
|
"output_sizes": [256, 256],
|
||||||
|
"activation": "elu"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"recurrent_net": {
|
||||||
|
"constructor": "iodine.modules.networks.LSTM",
|
||||||
|
"hidden_sizes": [256],
|
||||||
|
},
|
||||||
|
"refinement_head": {
|
||||||
|
"constructor": "iodine.modules.refinement.ResHead"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"latent_dist": {
|
||||||
|
"constructor": "iodine.modules.distributions.LocScaleDistribution",
|
||||||
|
"dist": "normal",
|
||||||
|
"scale_act": "softplus",
|
||||||
|
"scale": "var",
|
||||||
|
"name": "latent_dist",
|
||||||
|
},
|
||||||
|
"output_dist": {
|
||||||
|
"constructor": "iodine.modules.distributions.MaskedMixture",
|
||||||
|
"num_components": num_components,
|
||||||
|
"component_dist": {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.distributions.LocScaleDistribution",
|
||||||
|
"dist":
|
||||||
|
"logistic",
|
||||||
|
"scale":
|
||||||
|
"fixed",
|
||||||
|
"scale_val":
|
||||||
|
0.03,
|
||||||
|
"name":
|
||||||
|
"pixel_distribution",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"factor_evaluator": {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.factor_eval.FactorRegressor",
|
||||||
|
"mapping": [
|
||||||
|
("color", 9, "categorical"),
|
||||||
|
("shape", 4, "categorical"),
|
||||||
|
("size", 3, "categorical"),
|
||||||
|
("position", 3, "scalar"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
optimizer = {
|
||||||
|
"constructor": "tensorflow.train.AdamOptimizer",
|
||||||
|
"learning_rate": {
|
||||||
|
"constructor": "tensorflow.train.exponential_decay",
|
||||||
|
"learning_rate": learn_rate,
|
||||||
|
"global_step": {
|
||||||
|
"constructor": "tensorflow.train.get_or_create_global_step"
|
||||||
|
},
|
||||||
|
"decay_steps": 1000000,
|
||||||
|
"decay_rate": 0.1,
|
||||||
|
},
|
||||||
|
"beta1": 0.95,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def multi_dsprites():
|
||||||
|
n_z = 16 # number of latent dimensions
|
||||||
|
num_components = 6 # number of components (K)
|
||||||
|
num_iters = 5
|
||||||
|
checkpoint_dir = "iodine/checkpoints/multi_dsprites"
|
||||||
|
|
||||||
|
# For the paper we used 8 GPUs with a batch size of 16 each.
|
||||||
|
# This means a total batch size of 128, which is too large for a single GPU.
|
||||||
|
# When reducing the batch size, the learning rate should also be lowered.
|
||||||
|
batch_size = 16
|
||||||
|
learn_rate = 0.0003 * math.sqrt(batch_size / 128)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.data.MultiDSprites",
|
||||||
|
"batch_size":
|
||||||
|
batch_size,
|
||||||
|
"path":
|
||||||
|
"multi_object_datasets/multi_dsprites_colored_on_grayscale.tfrecords",
|
||||||
|
"dataset_variant":
|
||||||
|
"colored_on_grayscale",
|
||||||
|
"min_num_objs":
|
||||||
|
3,
|
||||||
|
"max_num_objs":
|
||||||
|
3,
|
||||||
|
}
|
||||||
|
|
||||||
|
model = {
|
||||||
|
"constructor": "iodine.modules.iodine.IODINE",
|
||||||
|
"n_z": n_z,
|
||||||
|
"num_components": num_components,
|
||||||
|
"num_iters": num_iters,
|
||||||
|
"iter_loss_weight": "linspace",
|
||||||
|
"coord_type": "cos",
|
||||||
|
"coord_freqs": 3,
|
||||||
|
"decoder": {
|
||||||
|
"constructor": "iodine.modules.decoder.ComponentDecoder",
|
||||||
|
"pixel_decoder": {
|
||||||
|
"constructor": "iodine.modules.networks.BroadcastConv",
|
||||||
|
"cnn_opt": {
|
||||||
|
# Final channels is irrelevant with target_output_shape
|
||||||
|
"output_channels": [32, 32, 32, 32, None],
|
||||||
|
"kernel_shapes": [5],
|
||||||
|
"strides": [1],
|
||||||
|
"activation": "elu",
|
||||||
|
},
|
||||||
|
"coord_type": "linear",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"refinement_core": {
|
||||||
|
"constructor": "iodine.modules.refinement.RefinementCore",
|
||||||
|
"encoder_net": {
|
||||||
|
"constructor": "iodine.modules.networks.CNN",
|
||||||
|
"mode": "avg_pool",
|
||||||
|
"cnn_opt": {
|
||||||
|
"output_channels": [32, 32, 32],
|
||||||
|
"strides": [2],
|
||||||
|
"kernel_shapes": [5],
|
||||||
|
"activation": "elu",
|
||||||
|
},
|
||||||
|
"mlp_opt": {
|
||||||
|
"output_sizes": [128],
|
||||||
|
"activation": "elu"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"recurrent_net": {
|
||||||
|
"constructor": "iodine.modules.networks.LSTM",
|
||||||
|
"hidden_sizes": [128],
|
||||||
|
},
|
||||||
|
"refinement_head": {
|
||||||
|
"constructor": "iodine.modules.refinement.ResHead"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"latent_dist": {
|
||||||
|
"constructor": "iodine.modules.distributions.LocScaleDistribution",
|
||||||
|
"dist": "normal",
|
||||||
|
"scale_act": "softplus",
|
||||||
|
"scale": "var",
|
||||||
|
"name": "latent_dist",
|
||||||
|
},
|
||||||
|
"output_dist": {
|
||||||
|
"constructor": "iodine.modules.distributions.MaskedMixture",
|
||||||
|
"num_components": num_components,
|
||||||
|
"component_dist": {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.distributions.LocScaleDistribution",
|
||||||
|
"dist":
|
||||||
|
"logistic",
|
||||||
|
"scale":
|
||||||
|
"fixed",
|
||||||
|
"scale_val":
|
||||||
|
0.03,
|
||||||
|
"name":
|
||||||
|
"pixel_distribution",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"factor_evaluator": {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.factor_eval.FactorRegressor",
|
||||||
|
"mapping": [
|
||||||
|
("color", 3, "scalar"),
|
||||||
|
("shape", 4, "categorical"),
|
||||||
|
("scale", 1, "scalar"),
|
||||||
|
("x", 1, "scalar"),
|
||||||
|
("y", 1, "scalar"),
|
||||||
|
("orientation", 2, "angle"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
optimizer = {
|
||||||
|
"constructor": "tensorflow.train.AdamOptimizer",
|
||||||
|
"learning_rate": {
|
||||||
|
"constructor": "tensorflow.train.exponential_decay",
|
||||||
|
"learning_rate": learn_rate,
|
||||||
|
"global_step": {
|
||||||
|
"constructor": "tensorflow.train.get_or_create_global_step"
|
||||||
|
},
|
||||||
|
"decay_steps": 1000000,
|
||||||
|
"decay_rate": 0.1,
|
||||||
|
},
|
||||||
|
"beta1": 0.95,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def tetrominoes():
|
||||||
|
n_z = 32 # number of latent dimensions
|
||||||
|
num_components = 4 # number of components (K)
|
||||||
|
num_iters = 5
|
||||||
|
checkpoint_dir = "iodine/checkpoints/tetrominoes"
|
||||||
|
|
||||||
|
# For the paper we used 8 GPUs with a batch size of 32 each.
|
||||||
|
# This means a total batch size of 256, which is too large for a single GPU.
|
||||||
|
# When reducing the batch size, the learning rate should also be lowered.
|
||||||
|
batch_size = 128
|
||||||
|
learn_rate = 0.0003 * math.sqrt(batch_size / 256)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"constructor": "iodine.modules.data.Tetrominoes",
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"path": "iodine/multi_object_datasets/tetrominoes_train.tfrecords",
|
||||||
|
}
|
||||||
|
|
||||||
|
model = {
|
||||||
|
"constructor": "iodine.modules.iodine.IODINE",
|
||||||
|
"n_z": n_z,
|
||||||
|
"num_components": num_components,
|
||||||
|
"num_iters": num_iters,
|
||||||
|
"iter_loss_weight": "linspace",
|
||||||
|
"coord_type": "cos",
|
||||||
|
"coord_freqs": 3,
|
||||||
|
"decoder": {
|
||||||
|
"constructor": "iodine.modules.decoder.ComponentDecoder",
|
||||||
|
"pixel_decoder": {
|
||||||
|
"constructor": "iodine.modules.networks.BroadcastConv",
|
||||||
|
"cnn_opt": {
|
||||||
|
# Final channels is irrelevant with target_output_shape
|
||||||
|
"output_channels": [32, 32, 32, 32, None],
|
||||||
|
"kernel_shapes": [5],
|
||||||
|
"strides": [1],
|
||||||
|
"activation": "elu",
|
||||||
|
},
|
||||||
|
"coord_type": "linear",
|
||||||
|
"coord_freqs": 3,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"refinement_core": {
|
||||||
|
"constructor": "iodine.modules.refinement.RefinementCore",
|
||||||
|
"encoder_net": {
|
||||||
|
"constructor": "iodine.modules.networks.CNN",
|
||||||
|
"mode": "avg_pool",
|
||||||
|
"cnn_opt": {
|
||||||
|
"output_channels": [32, 32, 32],
|
||||||
|
"strides": [2],
|
||||||
|
"kernel_shapes": [5],
|
||||||
|
"activation": "elu",
|
||||||
|
},
|
||||||
|
"mlp_opt": {
|
||||||
|
"output_sizes": [128],
|
||||||
|
"activation": "elu"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"recurrent_net": {
|
||||||
|
"constructor": "iodine.modules.networks.LSTM",
|
||||||
|
"hidden_sizes": [], # No recurrent layer used for this dataset
|
||||||
|
},
|
||||||
|
"refinement_head": {
|
||||||
|
"constructor": "iodine.modules.refinement.ResHead"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"latent_dist": {
|
||||||
|
"constructor": "iodine.modules.distributions.LocScaleDistribution",
|
||||||
|
"dist": "normal",
|
||||||
|
"scale_act": "softplus",
|
||||||
|
"scale": "var",
|
||||||
|
"name": "latent_dist",
|
||||||
|
},
|
||||||
|
"output_dist": {
|
||||||
|
"constructor": "iodine.modules.distributions.MaskedMixture",
|
||||||
|
"num_components": num_components,
|
||||||
|
"component_dist": {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.distributions.LocScaleDistribution",
|
||||||
|
"dist":
|
||||||
|
"logistic",
|
||||||
|
"scale":
|
||||||
|
"fixed",
|
||||||
|
"scale_val":
|
||||||
|
0.03,
|
||||||
|
"name":
|
||||||
|
"pixel_distribution",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"factor_evaluator": {
|
||||||
|
"constructor":
|
||||||
|
"iodine.modules.factor_eval.FactorRegressor",
|
||||||
|
"mapping": [
|
||||||
|
("position", 2, "scalar"),
|
||||||
|
("color", 3, "scalar"),
|
||||||
|
("shape", 20, "categorical"),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
optimizer = {
|
||||||
|
"constructor": "tensorflow.train.AdamOptimizer",
|
||||||
|
"learning_rate": {
|
||||||
|
"constructor": "tensorflow.train.exponential_decay",
|
||||||
|
"learning_rate": learn_rate,
|
||||||
|
"global_step": {
|
||||||
|
"constructor": "tensorflow.train.get_or_create_global_step"
|
||||||
|
},
|
||||||
|
"decay_steps": 1000000,
|
||||||
|
"decay_rate": 0.1,
|
||||||
|
},
|
||||||
|
"beta1": 0.95,
|
||||||
|
}
|
||||||
Executable
+20
@@ -0,0 +1,20 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
pushd iodine
|
||||||
|
wget http://storage.googleapis.com/deepmind-research-iodine/iodine_checkpoints.zip
|
||||||
|
unzip iodine_checkpoints.zip
|
||||||
|
popd
|
||||||
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 148 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
+202
@@ -0,0 +1,202 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# pylint: disable=g-importing-member, g-multiple-import, g-import-not-at-top
|
||||||
|
# pylint: disable=protected-access, g-bad-import-order, missing-docstring
|
||||||
|
# pylint: disable=unused-variable, invalid-name, no-value-for-parameter
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
|
import os.path
|
||||||
|
import warnings
|
||||||
|
from absl import logging
|
||||||
|
import numpy as np
|
||||||
|
from sacred import Experiment, SETTINGS
|
||||||
|
|
||||||
|
# Ignore all tensorflow deprecation warnings
|
||||||
|
logging._warn_preinit_stderr = 0
|
||||||
|
warnings.filterwarnings("ignore", module=".*tensorflow.*")
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
tf.logging.set_verbosity(tf.logging.ERROR)
|
||||||
|
import sonnet as snt
|
||||||
|
from sacred.stflow import LogFileWriter
|
||||||
|
from iodine.modules import utils
|
||||||
|
from iodine import configurations
|
||||||
|
|
||||||
|
SETTINGS.CONFIG.READ_ONLY_CONFIG = False
|
||||||
|
|
||||||
|
ex = Experiment("iodine")
|
||||||
|
|
||||||
|
|
||||||
|
@ex.config
|
||||||
|
def default_config():
|
||||||
|
continue_run = False # set to continue experiment from an existing checkpoint
|
||||||
|
checkpoint_dir = ("checkpoints/iodine"
|
||||||
|
) # if continue_run is False, "_{run_id}" will be appended
|
||||||
|
save_summaries_steps = 10
|
||||||
|
save_checkpoint_steps = 1000
|
||||||
|
|
||||||
|
n_z = 64 # number of latent dimensions
|
||||||
|
num_components = 7 # number of components (K)
|
||||||
|
num_iters = 5
|
||||||
|
|
||||||
|
learn_rate = 0.001
|
||||||
|
batch_size = 4
|
||||||
|
stop_after_steps = int(1e6)
|
||||||
|
|
||||||
|
# Details for the dataset, model and optimizer are left empty here.
|
||||||
|
# They can be found in the configurations for individual datasets,
|
||||||
|
# which are provided in configurations.py and added as named configs.
|
||||||
|
data = {} # Dataset details will go here
|
||||||
|
model = {} # Model details will go here
|
||||||
|
optimizer = {} # Optimizer details will go here
|
||||||
|
|
||||||
|
|
||||||
|
ex.named_config(configurations.clevr6)
|
||||||
|
ex.named_config(configurations.multi_dsprites)
|
||||||
|
ex.named_config(configurations.tetrominoes)
|
||||||
|
|
||||||
|
|
||||||
|
@ex.capture
|
||||||
|
def build(identifier, _config):
|
||||||
|
config_copy = deepcopy(_config[identifier])
|
||||||
|
return utils.build(config_copy, identifier=identifier)
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_step(model, dataset, optimizer):
|
||||||
|
loss, scalars, _ = model(dataset("train"))
|
||||||
|
global_step = tf.train.get_or_create_global_step()
|
||||||
|
grads = optimizer.compute_gradients(loss)
|
||||||
|
gradients, variables = zip(*grads)
|
||||||
|
global_norm = tf.global_norm(gradients)
|
||||||
|
gradients, global_norm = tf.clip_by_global_norm(
|
||||||
|
gradients, 5.0, use_norm=global_norm)
|
||||||
|
grads = zip(gradients, variables)
|
||||||
|
train_op = optimizer.apply_gradients(grads, global_step=global_step)
|
||||||
|
|
||||||
|
with tf.control_dependencies([train_op]):
|
||||||
|
overview = model.get_overview_images(dataset("summary"))
|
||||||
|
scalars["debug/global_grad_norm"] = global_norm
|
||||||
|
|
||||||
|
summaries = {
|
||||||
|
k: tf.summary.scalar(k, v) for k, v in scalars.items()
|
||||||
|
}
|
||||||
|
summaries.update(
|
||||||
|
{k: tf.summary.image(k, v) for k, v in overview.items()})
|
||||||
|
|
||||||
|
return tf.identity(global_step), scalars, train_op
|
||||||
|
|
||||||
|
|
||||||
|
@ex.capture
|
||||||
|
def get_checkpoint_dir(continue_run, checkpoint_dir, _run, _log):
|
||||||
|
if continue_run:
|
||||||
|
assert os.path.exists(checkpoint_dir)
|
||||||
|
_log.info("Continuing run from checkpoint at {}".format(checkpoint_dir))
|
||||||
|
return checkpoint_dir
|
||||||
|
|
||||||
|
run_id = _run._id
|
||||||
|
if run_id is None: # then no observer was added that provided an _id
|
||||||
|
if not _run.unobserved:
|
||||||
|
_log.warning(
|
||||||
|
"No run_id given or provided by an Observer. (Re-)using run_id=1.")
|
||||||
|
run_id = 1
|
||||||
|
checkpoint_dir = checkpoint_dir + "_{run_id}".format(run_id=run_id)
|
||||||
|
_log.info(
|
||||||
|
"Starting a new run using checkpoint dir: '{}'".format(checkpoint_dir))
|
||||||
|
return checkpoint_dir
|
||||||
|
|
||||||
|
|
||||||
|
@ex.capture
|
||||||
|
def get_session(chkp_dir, loss, stop_after_steps, save_summaries_steps,
|
||||||
|
save_checkpoint_steps):
|
||||||
|
config = tf.ConfigProto()
|
||||||
|
config.gpu_options.allow_growth = True
|
||||||
|
|
||||||
|
hooks = [
|
||||||
|
tf.train.StopAtStepHook(last_step=stop_after_steps),
|
||||||
|
tf.train.NanTensorHook(loss),
|
||||||
|
]
|
||||||
|
|
||||||
|
return tf.train.MonitoredTrainingSession(
|
||||||
|
hooks=hooks,
|
||||||
|
config=config,
|
||||||
|
checkpoint_dir=chkp_dir,
|
||||||
|
save_summaries_steps=save_summaries_steps,
|
||||||
|
save_checkpoint_steps=save_checkpoint_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ex.command(unobserved=True)
|
||||||
|
def load_checkpoint(use_placeholder=False, session=None):
|
||||||
|
dataset = build("data")
|
||||||
|
model = build("model")
|
||||||
|
if use_placeholder:
|
||||||
|
inputs = dataset.get_placeholders()
|
||||||
|
else:
|
||||||
|
inputs = dataset()
|
||||||
|
|
||||||
|
info = model.eval(inputs)
|
||||||
|
if session is None:
|
||||||
|
session = tf.Session()
|
||||||
|
saver = tf.train.Saver()
|
||||||
|
checkpoint_dir = get_checkpoint_dir()
|
||||||
|
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
|
||||||
|
saver.restore(session, checkpoint_file)
|
||||||
|
|
||||||
|
print('Successfully restored Checkpoint "{}"'.format(checkpoint_file))
|
||||||
|
# print variables
|
||||||
|
variables = tf.global_variables() + tf.local_variables()
|
||||||
|
for row in snt.format_variables(variables, join_lines=False):
|
||||||
|
print(row)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"session": session,
|
||||||
|
"model": model,
|
||||||
|
"info": info,
|
||||||
|
"inputs": inputs,
|
||||||
|
"dataset": dataset,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ex.automain
|
||||||
|
@LogFileWriter(ex)
|
||||||
|
def main(save_summaries_steps):
|
||||||
|
checkpoint_dir = get_checkpoint_dir()
|
||||||
|
|
||||||
|
dataset = build("data")
|
||||||
|
model = build("model")
|
||||||
|
optimizer = build("optimizer")
|
||||||
|
gstep, train_step_exports, train_op = get_train_step(model, dataset,
|
||||||
|
optimizer)
|
||||||
|
|
||||||
|
loss, ari = [], []
|
||||||
|
with get_session(checkpoint_dir, train_step_exports["loss/total"]) as sess:
|
||||||
|
while not sess.should_stop():
|
||||||
|
out = sess.run({
|
||||||
|
"step": gstep,
|
||||||
|
"loss": train_step_exports["loss/total"],
|
||||||
|
"ari": train_step_exports["loss/ari_nobg"],
|
||||||
|
"train": train_op,
|
||||||
|
})
|
||||||
|
loss.append(out["loss"])
|
||||||
|
ari.append(out["ari"])
|
||||||
|
step = out["step"]
|
||||||
|
if step % save_summaries_steps == 0:
|
||||||
|
mean_loss = np.mean(loss)
|
||||||
|
mean_ari = np.mean(ari)
|
||||||
|
ex.log_scalar("loss", mean_loss, step)
|
||||||
|
ex.log_scalar("ari", mean_ari, step)
|
||||||
|
print("{step:>6d} Loss: {loss: >12.2f}\t\tARI-nobg:{ari: >6.2f}".format(
|
||||||
|
step=step, loss=mean_loss, ari=mean_ari))
|
||||||
|
loss, ari = [], []
|
||||||
@@ -0,0 +1,13 @@
|
|||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
@@ -0,0 +1,264 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Data loading functionality for IODINE."""
|
||||||
|
# pylint: disable=g-multiple-import, missing-docstring, unused-import
|
||||||
|
import os.path
|
||||||
|
|
||||||
|
from iodine.modules.utils import flatten_all_but_last, ensure_3d
|
||||||
|
from multi_object_datasets import (
|
||||||
|
clevr_with_masks,
|
||||||
|
multi_dsprites,
|
||||||
|
tetrominoes,
|
||||||
|
objects_room,
|
||||||
|
)
|
||||||
|
from shapeguard import ShapeGuard
|
||||||
|
import sonnet as snt
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
|
class IODINEDataset(snt.AbstractModule):
|
||||||
|
num_true_objects = 1
|
||||||
|
num_channels = 3
|
||||||
|
|
||||||
|
factors = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path,
|
||||||
|
batch_size,
|
||||||
|
image_dim,
|
||||||
|
crop_region=None,
|
||||||
|
shuffle_buffer=1000,
|
||||||
|
max_num_objects=None,
|
||||||
|
min_num_objects=None,
|
||||||
|
grayscale=False,
|
||||||
|
name="dataset",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self.path = os.path.abspath(os.path.expanduser(path))
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.crop_region = crop_region
|
||||||
|
self.image_dim = image_dim
|
||||||
|
self.shuffle_buffer = shuffle_buffer
|
||||||
|
self.max_num_objects = max_num_objects
|
||||||
|
self.min_num_objects = min_num_objects
|
||||||
|
self.grayscale = grayscale
|
||||||
|
self.dataset = None
|
||||||
|
|
||||||
|
def _build(self, subset="train"):
|
||||||
|
dataset = self.dataset
|
||||||
|
|
||||||
|
# filter by number of objects
|
||||||
|
if self.max_num_objects is not None or self.min_num_objects is not None:
|
||||||
|
dataset = self.dataset.filter(self.filter_by_num_objects)
|
||||||
|
|
||||||
|
if subset == "train":
|
||||||
|
# normal mode returns a shuffled dataset iterator
|
||||||
|
if self.shuffle_buffer is not None:
|
||||||
|
dataset = dataset.shuffle(self.shuffle_buffer)
|
||||||
|
elif subset == "summary":
|
||||||
|
# for generating summaries and overview images
|
||||||
|
# returns a single fixed batch
|
||||||
|
dataset = dataset.take(self.batch_size)
|
||||||
|
|
||||||
|
# repeat and batch
|
||||||
|
dataset = dataset.repeat().batch(self.batch_size, drop_remainder=True)
|
||||||
|
iterator = dataset.make_one_shot_iterator()
|
||||||
|
data = iterator.get_next()
|
||||||
|
|
||||||
|
# preprocess the data to ensure correct format, scale images etc.
|
||||||
|
data = self.preprocess(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def filter_by_num_objects(self, d):
|
||||||
|
if "visibility" not in d:
|
||||||
|
return tf.constant(True)
|
||||||
|
min_num_objects = self.max_num_objects or 0
|
||||||
|
max_num_objects = self.max_num_objects or 6
|
||||||
|
|
||||||
|
min_predicate = tf.greater_equal(
|
||||||
|
tf.reduce_sum(d["visibility"]),
|
||||||
|
tf.constant(min_num_objects - 1e-5, dtype=tf.float32),
|
||||||
|
)
|
||||||
|
max_predicate = tf.less_equal(
|
||||||
|
tf.reduce_sum(d["visibility"]),
|
||||||
|
tf.constant(max_num_objects + 1e-5, dtype=tf.float32),
|
||||||
|
)
|
||||||
|
return tf.logical_and(min_predicate, max_predicate)
|
||||||
|
|
||||||
|
def preprocess(self, data):
|
||||||
|
sg = ShapeGuard(dims={
|
||||||
|
"B": self.batch_size,
|
||||||
|
"H": self.image_dim[0],
|
||||||
|
"W": self.image_dim[1]
|
||||||
|
})
|
||||||
|
image = sg.guard(data["image"], "B, h, w, C")
|
||||||
|
mask = sg.guard(data["mask"], "B, L, h, w, 1")
|
||||||
|
|
||||||
|
# to float
|
||||||
|
image = tf.cast(image, tf.float32) / 255.0
|
||||||
|
mask = tf.cast(mask, tf.float32) / 255.0
|
||||||
|
|
||||||
|
# crop
|
||||||
|
if self.crop_region is not None:
|
||||||
|
height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
|
||||||
|
width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
|
||||||
|
image = image[:, height_slice, width_slice, :]
|
||||||
|
|
||||||
|
mask = mask[:, :, height_slice, width_slice, :]
|
||||||
|
|
||||||
|
flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)
|
||||||
|
|
||||||
|
# rescale
|
||||||
|
size = tf.constant(
|
||||||
|
self.image_dim, dtype=tf.int32, shape=[2], verify_shape=True)
|
||||||
|
image = tf.image.resize_images(
|
||||||
|
image, size, method=tf.image.ResizeMethod.BILINEAR)
|
||||||
|
mask = tf.image.resize_images(
|
||||||
|
flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||||
|
|
||||||
|
if self.grayscale:
|
||||||
|
image = tf.reduce_mean(image, axis=-1, keepdims=True)
|
||||||
|
|
||||||
|
output = {
|
||||||
|
"image": sg.guard(image[:, None], "B, T, H, W, C"),
|
||||||
|
"mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
|
||||||
|
"factors": self.preprocess_factors(data, sg),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "visibility" in data:
|
||||||
|
output["visibility"] = sg.guard(data["visibility"], "B, L")
|
||||||
|
else:
|
||||||
|
output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def preprocess_factors(self, data, sg):
|
||||||
|
return {
|
||||||
|
name: sg.guard(ensure_3d(data[name]), "B, L, *")
|
||||||
|
for name in self.factors
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_placeholders(self, batch_size=None):
|
||||||
|
batch_size = batch_size or self.batch_size
|
||||||
|
sg = ShapeGuard(
|
||||||
|
dims={
|
||||||
|
"B": batch_size,
|
||||||
|
"H": self.image_dim[0],
|
||||||
|
"W": self.image_dim[1],
|
||||||
|
"L": self.num_true_objects,
|
||||||
|
"C": 3,
|
||||||
|
"T": 1,
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"image": tf.placeholder(dtype=tf.float32, shape=sg["B, T, H, W, C"]),
|
||||||
|
"mask": tf.placeholder(dtype=tf.float32, shape=sg["B, T, L, H, W, 1"]),
|
||||||
|
"visibility": tf.placeholder(dtype=tf.float32, shape=sg["B, L"]),
|
||||||
|
"factors": {
|
||||||
|
name:
|
||||||
|
tf.placeholder(dtype=dtype, shape=sg["B, L, {}".format(size)])
|
||||||
|
for name, (dtype, size) in self.factors
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CLEVR(IODINEDataset):
|
||||||
|
num_true_objects = 11
|
||||||
|
num_channels = 3
|
||||||
|
factors = {
|
||||||
|
"color": (tf.uint8, 1),
|
||||||
|
"shape": (tf.uint8, 1),
|
||||||
|
"size": (tf.uint8, 1),
|
||||||
|
"position": (tf.float32, 3),
|
||||||
|
"rotation": (tf.float32, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path,
|
||||||
|
crop_region=((29, 221), (64, 256)),
|
||||||
|
image_dim=(128, 128),
|
||||||
|
name="clevr",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
path=path,
|
||||||
|
crop_region=crop_region,
|
||||||
|
image_dim=image_dim,
|
||||||
|
name=name,
|
||||||
|
**kwargs)
|
||||||
|
self.dataset = clevr_with_masks.dataset(self.path)
|
||||||
|
|
||||||
|
def preprocess_factors(self, data, sg):
|
||||||
|
|
||||||
|
return {
|
||||||
|
"color": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
|
||||||
|
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
|
||||||
|
"size": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
|
||||||
|
"position": sg.guard(ensure_3d(data["pixel_coords"]), "B, L, 3"),
|
||||||
|
"rotation": sg.guard(ensure_3d(data["rotation"]), "B, L, 1"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MultiDSprites(IODINEDataset):
|
||||||
|
num_true_objects = 6
|
||||||
|
num_channels = 3
|
||||||
|
factors = {
|
||||||
|
"color": (tf.float32, 3),
|
||||||
|
"shape": (tf.uint8, 1),
|
||||||
|
"scale": (tf.float32, 1),
|
||||||
|
"x": (tf.float32, 1),
|
||||||
|
"y": (tf.float32, 1),
|
||||||
|
"orientation": (tf.float32, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
path,
|
||||||
|
# variant from ['binarized', 'colored_on_grayscale', 'colored_on_colored']
|
||||||
|
dataset_variant="colored_on_grayscale",
|
||||||
|
image_dim=(64, 64),
|
||||||
|
name="multi_dsprites",
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
|
||||||
|
self.dataset_variant = dataset_variant
|
||||||
|
self.dataset = multi_dsprites.dataset(self.path, self.dataset_variant)
|
||||||
|
|
||||||
|
|
||||||
|
class Tetrominoes(IODINEDataset):
|
||||||
|
num_true_objects = 6
|
||||||
|
num_channels = 3
|
||||||
|
factors = {
|
||||||
|
"color": (tf.uint8, 3),
|
||||||
|
"shape": (tf.uint8, 1),
|
||||||
|
"position": (tf.float32, 2),
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, path, image_dim=(35, 35), name="tetrominoes", **kwargs):
|
||||||
|
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
|
||||||
|
self.dataset = tetrominoes.dataset(self.path)
|
||||||
|
|
||||||
|
def preprocess_factors(self, data, sg):
|
||||||
|
pos_x = ensure_3d(data["x"])
|
||||||
|
pos_y = ensure_3d(data["y"])
|
||||||
|
position = tf.concat([pos_x, pos_y], axis=2)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"color": sg.guard(ensure_3d(data["color"]), "B, L, 3"),
|
||||||
|
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
|
||||||
|
"position": sg.guard(ensure_3d(position), "B, L, 2"),
|
||||||
|
}
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Decoders for rendering images."""
|
||||||
|
# pylint: disable=missing-docstring
|
||||||
|
from iodine.modules.distributions import MixtureParameters
|
||||||
|
import shapeguard
|
||||||
|
import sonnet as snt
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentDecoder(snt.AbstractModule):
|
||||||
|
|
||||||
|
def __init__(self, pixel_decoder, name="component_decoder"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._pixel_decoder = pixel_decoder
|
||||||
|
self._sg = shapeguard.ShapeGuard()
|
||||||
|
|
||||||
|
def set_output_shapes(self, pixel, mask):
|
||||||
|
self._sg.guard(pixel, "K, H, W, Cp")
|
||||||
|
self._sg.guard(mask, "K, H, W, 1")
|
||||||
|
self._pixel_decoder.set_output_shapes(self._sg["H, W, 1 + Cp"])
|
||||||
|
|
||||||
|
def _build(self, z):
|
||||||
|
self._sg.guard(z, "B, K, Z")
|
||||||
|
z_flat = self._sg.reshape(z, "B*K, Z")
|
||||||
|
pixel_params = self._pixel_decoder(z_flat).params
|
||||||
|
|
||||||
|
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
||||||
|
mask_params = pixel_params[Ellipsis, 0:1]
|
||||||
|
pixel_params = pixel_params[Ellipsis, 1:]
|
||||||
|
|
||||||
|
output = MixtureParameters(
|
||||||
|
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
||||||
|
mask=self._sg.reshape(mask_params, "B, K, H, W, 1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
del self._sg.B
|
||||||
|
return output
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Collection of sonnet modules that wrap useful distributions."""
|
||||||
|
# pylint: disable=missing-docstring, g-doc-args, g-short-docstring-punctuation
|
||||||
|
# pylint: disable=g-space-before-docstring-summary
|
||||||
|
# pylint: disable=g-no-space-after-docstring-summary
|
||||||
|
import collections
|
||||||
|
from iodine.modules.utils import get_act_func
|
||||||
|
from iodine.modules.utils import get_distribution
|
||||||
|
import shapeguard
|
||||||
|
import sonnet as snt
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
import tensorflow_probability as tfp
|
||||||
|
|
||||||
|
|
||||||
|
tfd = tfp.distributions
|
||||||
|
|
||||||
|
FlatParameters = collections.namedtuple("ParameterOut", ["params"])
|
||||||
|
MixtureParameters = collections.namedtuple("MixtureOut", ["pixel", "mask"])
|
||||||
|
|
||||||
|
|
||||||
|
class DistributionModule(snt.AbstractModule):
|
||||||
|
"""Distribution Base class supporting shape inference & default priors."""
|
||||||
|
|
||||||
|
def __init__(self, name="distribution"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._output_shape = None
|
||||||
|
|
||||||
|
def set_output_shape(self, shape):
|
||||||
|
self._output_shape = shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def output_shape(self):
|
||||||
|
return self._output_shape
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shapes(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_default_prior(self, batch_dim=(1,)):
|
||||||
|
return self(
|
||||||
|
tf.zeros(list(batch_dim) + self.input_shapes.params, dtype=tf.float32))
|
||||||
|
|
||||||
|
|
||||||
|
class BernoulliOutput(DistributionModule):
|
||||||
|
|
||||||
|
def __init__(self, name="bernoulli_output"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shapes(self):
|
||||||
|
return FlatParameters(self.output_shape)
|
||||||
|
|
||||||
|
def _build(self, params):
|
||||||
|
return tfd.Independent(
|
||||||
|
tfd.Bernoulli(logits=params, dtype=tf.float32),
|
||||||
|
reinterpreted_batch_ndims=1)
|
||||||
|
|
||||||
|
|
||||||
|
class LocScaleDistribution(DistributionModule):
|
||||||
|
"""Generic IID location / scale distribution.
|
||||||
|
|
||||||
|
Input parameters are concatenation of location and scale (2*Z,)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dist: Distribution or str Kind of distribution used. Supports Normal,
|
||||||
|
Logistic, Laplace, and StudentT distributions.
|
||||||
|
dist_kwargs: dict custom keyword arguments for the distribution
|
||||||
|
scale_act: function or str or None activation function to be applied to
|
||||||
|
the scale input
|
||||||
|
scale: str
|
||||||
|
different modes for computing the scale:
|
||||||
|
* stddev: scale is computed as scale_act(s)
|
||||||
|
* var: scale is computed as sqrt(scale_act(s))
|
||||||
|
* prec: scale is computed as 1./scale_act(s)
|
||||||
|
* fixed: scale is a global variable (same for all pixels) if
|
||||||
|
scale_val==-1. then it is a trainable variable initialized to 0.1
|
||||||
|
else it is fixed to scale_val (input shape is only (Z,) in this
|
||||||
|
case)
|
||||||
|
scale_val: float determines the scale value (only used if scale=='fixed').
|
||||||
|
loc_act: function or str or None activation function to be applied to the
|
||||||
|
location input. Supports optional activation functions for scale and
|
||||||
|
location.
|
||||||
|
Supports different "modes" for scaling:
|
||||||
|
* stddev:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dist=tfd.Normal,
|
||||||
|
dist_kwargs=None,
|
||||||
|
scale_act=tf.exp,
|
||||||
|
scale="stddev",
|
||||||
|
scale_val=1.0,
|
||||||
|
loc_act=None,
|
||||||
|
name="loc_scale_dist",
|
||||||
|
):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._scale_act = get_act_func(scale_act)
|
||||||
|
self._loc_act = get_act_func(loc_act)
|
||||||
|
# supports Normal, Logstic, Laplace, StudentT
|
||||||
|
self._dist = get_distribution(dist)
|
||||||
|
self._dist_kwargs = dist_kwargs or {}
|
||||||
|
|
||||||
|
assert scale in ["stddev", "var", "prec", "fixed"], scale
|
||||||
|
self._scale = scale
|
||||||
|
self._scale_val = scale_val
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shapes(self):
|
||||||
|
if self._scale == "fixed":
|
||||||
|
param_shape = self.output_shape
|
||||||
|
else:
|
||||||
|
param_shape = self.output_shape[:-1] + [self.output_shape[-1] * 2]
|
||||||
|
return FlatParameters(param_shape)
|
||||||
|
|
||||||
|
def _build(self, params):
|
||||||
|
if self._scale == "fixed":
|
||||||
|
loc = params
|
||||||
|
scale = None # set later
|
||||||
|
else:
|
||||||
|
n_channels = params.get_shape().as_list()[-1]
|
||||||
|
assert n_channels % 2 == 0
|
||||||
|
assert n_channels // 2 == self.output_shape[-1]
|
||||||
|
loc = params[Ellipsis, :n_channels // 2]
|
||||||
|
scale = params[Ellipsis, n_channels // 2:]
|
||||||
|
|
||||||
|
# apply activation functions
|
||||||
|
if self._scale != "fixed":
|
||||||
|
scale = self._scale_act(scale)
|
||||||
|
loc = self._loc_act(loc)
|
||||||
|
|
||||||
|
# apply the correct parametrization
|
||||||
|
if self._scale == "var":
|
||||||
|
scale = tf.sqrt(scale)
|
||||||
|
elif self._scale == "prec":
|
||||||
|
scale = tf.reciprocal(scale)
|
||||||
|
elif self._scale == "fixed":
|
||||||
|
if self._scale_val == -1.0:
|
||||||
|
scale_val = tf.get_variable(
|
||||||
|
"scale", initializer=tf.constant(0.1, dtype=tf.float32))
|
||||||
|
else:
|
||||||
|
scale_val = self._scale_val
|
||||||
|
scale = tf.ones_like(loc) * scale_val
|
||||||
|
# else 'stddev'
|
||||||
|
|
||||||
|
dist = self._dist(loc=loc, scale=scale, **self._dist_kwargs)
|
||||||
|
|
||||||
|
return tfd.Independent(dist, reinterpreted_batch_ndims=1)
|
||||||
|
|
||||||
|
|
||||||
|
class MaskedMixture(DistributionModule):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_components,
|
||||||
|
component_dist,
|
||||||
|
mask_activation=None,
|
||||||
|
name="masked_mixture",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Spatial Mixture Model composed of a categorical masking distribution and
|
||||||
|
a custom pixel-wise component distribution (usually logistic or
|
||||||
|
gaussian).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_components: int Number of mixture components >= 2
|
||||||
|
component_dist: the distribution to use for the individual components
|
||||||
|
mask_activation: str or function or None activation function that
|
||||||
|
should be applied to the mask before the softmax.
|
||||||
|
name: str
|
||||||
|
"""
|
||||||
|
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._num_components = num_components
|
||||||
|
self._dist = component_dist
|
||||||
|
self._mask_activation = get_act_func(mask_activation)
|
||||||
|
|
||||||
|
def set_output_shape(self, shape):
|
||||||
|
super().set_output_shape(shape)
|
||||||
|
self._dist.set_output_shape(shape)
|
||||||
|
|
||||||
|
def _build(self, pixel, mask):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
# MASKING
|
||||||
|
sg.guard(mask, "B, K, H, W, 1")
|
||||||
|
mask = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
|
||||||
|
mask = sg.reshape(mask, "B, H, W, K")
|
||||||
|
mask = self._mask_activation(mask)
|
||||||
|
mask = mask[:, tf.newaxis] # add K=1 axis since K is removed by mixture
|
||||||
|
mix_dist = tfd.Categorical(logits=mask)
|
||||||
|
|
||||||
|
# COMPONENTS
|
||||||
|
sg.guard(pixel, "B, K, H, W, Cp")
|
||||||
|
params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
|
||||||
|
params = params[:, tf.newaxis] # add K=1 axis since K is removed by mixture
|
||||||
|
dist = self._dist(params)
|
||||||
|
return tfd.MixtureSameFamily(
|
||||||
|
mixture_distribution=mix_dist, components_distribution=dist)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input_shapes(self):
|
||||||
|
pixel = [self._num_components] + self._dist.input_shapes.params
|
||||||
|
mask = pixel[:-1] + [1]
|
||||||
|
return MixtureParameters(pixel, mask)
|
||||||
|
|
||||||
|
def get_default_prior(self, batch_dim=(1,)):
|
||||||
|
pixel = tf.zeros(
|
||||||
|
list(batch_dim) + self.input_shapes.pixel, dtype=tf.float32)
|
||||||
|
mask = tf.zeros(list(batch_dim) + self.input_shapes.mask, dtype=tf.float32)
|
||||||
|
return self(pixel, mask)
|
||||||
@@ -0,0 +1,206 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Factor Evaluation Module."""
|
||||||
|
# pylint: disable=unused-variable
|
||||||
|
|
||||||
|
import collections
|
||||||
|
import functools
|
||||||
|
from iodine.modules import utils
|
||||||
|
import shapeguard
|
||||||
|
import sonnet as snt
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
Factor = collections.namedtuple("Factor", ["name", "size", "type"])
|
||||||
|
|
||||||
|
|
||||||
|
class FactorRegressor(snt.AbstractModule):
|
||||||
|
"""Assess representations by learning a linear mapping to latents."""
|
||||||
|
|
||||||
|
def __init__(self, mapping=None, name="repres_content"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
if mapping is None:
|
||||||
|
self._mapping = [
|
||||||
|
Factor("color", 3, "scalar"),
|
||||||
|
Factor("shape", 4, "categorical"),
|
||||||
|
Factor("scale", 1, "scalar"),
|
||||||
|
Factor("x", 1, "scalar"),
|
||||||
|
Factor("y", 1, "scalar"),
|
||||||
|
Factor("orientation", 2, "angle"),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self._mapping = [Factor(*m) for m in mapping]
|
||||||
|
|
||||||
|
def _build(self, z, latent, visibility, pred_mask, true_mask):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
z = sg.guard(z, "B, K, Z")
|
||||||
|
pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
|
||||||
|
true_mask = sg.guard(true_mask, "B, L, H, W, 1")
|
||||||
|
|
||||||
|
visibility = sg.guard(visibility, "B, L")
|
||||||
|
num_visible_obj = tf.reduce_sum(visibility)
|
||||||
|
|
||||||
|
# Map z to predictions for all latents
|
||||||
|
sg.M = sum([m.size for m in self._mapping])
|
||||||
|
self.predictor = snt.Linear(sg.M, name="predict_latents")
|
||||||
|
z_flat = sg.reshape(z, "B*K, Z")
|
||||||
|
all_preds = sg.guard(self.predictor(z_flat), "B*K, M")
|
||||||
|
all_preds = sg.reshape(all_preds, "B, 1, K, M")
|
||||||
|
all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])
|
||||||
|
|
||||||
|
# prepare latents
|
||||||
|
latents = {}
|
||||||
|
mean_var_tot = {}
|
||||||
|
for m in self._mapping:
|
||||||
|
with tf.name_scope(m.name):
|
||||||
|
# preprocess, reshape, and tile
|
||||||
|
lat_preprocess = self.get_preprocessing(m)
|
||||||
|
lat = sg.guard(
|
||||||
|
lat_preprocess(latent[m.name]), "B, L, {}".format(m.size))
|
||||||
|
# compute mean over latent by training a variable using mse
|
||||||
|
if m.type in {"scalar", "angle"}:
|
||||||
|
mvt = utils.OnlineMeanVarEstimator(
|
||||||
|
axis=[0, 1], ddof=1, name="{}_mean_var".format(m.name))
|
||||||
|
mean_var_tot[m.name] = mvt(lat, visibility[:, :, tf.newaxis])
|
||||||
|
|
||||||
|
lat = tf.reshape(lat, sg["B, L, 1"] + [-1])
|
||||||
|
lat = tf.tile(lat, sg["1, 1, K, 1"])
|
||||||
|
latents[m.name] = lat
|
||||||
|
|
||||||
|
# prepare predictions
|
||||||
|
idx = 0
|
||||||
|
predictions = {}
|
||||||
|
for m in self._mapping:
|
||||||
|
with tf.name_scope(m.name):
|
||||||
|
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
||||||
|
pred = all_preds[Ellipsis, idx:idx + m.size]
|
||||||
|
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
||||||
|
idx += m.size
|
||||||
|
|
||||||
|
# compute error
|
||||||
|
total_pairwise_errors = None
|
||||||
|
for m in self._mapping:
|
||||||
|
with tf.name_scope(m.name):
|
||||||
|
error_fn = self.get_error_func(m)
|
||||||
|
sg.guard(latents[m.name], "B, L, K, {}".format(m.size))
|
||||||
|
sg.guard(predictions[m.name], "B, L, K, {}".format(m.size))
|
||||||
|
err = error_fn(latents[m.name], predictions[m.name])
|
||||||
|
sg.guard(err, "B, L, K")
|
||||||
|
if total_pairwise_errors is None:
|
||||||
|
total_pairwise_errors = err
|
||||||
|
else:
|
||||||
|
total_pairwise_errors += err
|
||||||
|
|
||||||
|
# determine best assignment by comparing masks
|
||||||
|
obj_mask = true_mask[:, :, tf.newaxis]
|
||||||
|
pred_mask = pred_mask[:, tf.newaxis]
|
||||||
|
pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
|
||||||
|
best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
|
||||||
|
assignment = tf.one_hot(best_match, sg.K)
|
||||||
|
assignment *= visibility[:, :, tf.newaxis] # Mask non-visible objects
|
||||||
|
|
||||||
|
# total error
|
||||||
|
total_error = (
|
||||||
|
tf.reduce_sum(assignment * total_pairwise_errors) / num_visible_obj)
|
||||||
|
|
||||||
|
# compute scalars
|
||||||
|
monitored_scalars = {}
|
||||||
|
for m in self._mapping:
|
||||||
|
with tf.name_scope(m.name):
|
||||||
|
metric = self.get_metric(m)
|
||||||
|
scalar = metric(
|
||||||
|
latents[m.name],
|
||||||
|
predictions[m.name],
|
||||||
|
assignment[:, :, :, tf.newaxis],
|
||||||
|
mean_var_tot.get(m.name),
|
||||||
|
num_visible_obj,
|
||||||
|
)
|
||||||
|
monitored_scalars[m.name] = scalar
|
||||||
|
return total_error, monitored_scalars, mean_var_tot, predictions, assignment
|
||||||
|
|
||||||
|
@snt.reuse_variables
|
||||||
|
def predict(self, z):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
z = sg.guard(z, "B, Z")
|
||||||
|
all_preds = sg.guard(self.predictor(z), "B, M")
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
predictions = {}
|
||||||
|
for m in self._mapping:
|
||||||
|
with tf.name_scope(m.name):
|
||||||
|
pred = all_preds[:, idx:idx + m.size]
|
||||||
|
predictions[m.name] = sg.guard(pred, "B, {}".format(m.size))
|
||||||
|
idx += m.size
|
||||||
|
return predictions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_error_func(factor):
|
||||||
|
if factor.type in {"scalar", "angle"}:
|
||||||
|
return sse
|
||||||
|
elif factor.type == "categorical":
|
||||||
|
return functools.partial(
|
||||||
|
tf.losses.softmax_cross_entropy, reduction="none")
|
||||||
|
else:
|
||||||
|
raise KeyError(factor.type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_metric(factor):
|
||||||
|
if factor.type in {"scalar", "angle"}:
|
||||||
|
return r2
|
||||||
|
elif factor.type == "categorical":
|
||||||
|
return accuracy
|
||||||
|
else:
|
||||||
|
raise KeyError(factor.type)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def one_hot(f, nr_categories):
|
||||||
|
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def angle_to_vector(theta):
|
||||||
|
return tf.concat([tf.math.cos(theta), tf.math.sin(theta)], axis=-1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_preprocessing(factor):
|
||||||
|
if factor.type == "scalar":
|
||||||
|
return tf.identity
|
||||||
|
elif factor.type == "categorical":
|
||||||
|
return functools.partial(
|
||||||
|
FactorRegressor.one_hot, nr_categories=factor.size)
|
||||||
|
elif factor.type == "angle":
|
||||||
|
return FactorRegressor.angle_to_vector
|
||||||
|
else:
|
||||||
|
raise KeyError(factor.type)
|
||||||
|
|
||||||
|
|
||||||
|
def sse(true, pred):
|
||||||
|
# run our own sum squared error because we want to reduce sum over last dim
|
||||||
|
return tf.reduce_sum(tf.square(true - pred), axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
|
||||||
|
del mean_var_tot # unused
|
||||||
|
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||||
|
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
||||||
|
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
||||||
|
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
|
||||||
|
|
||||||
|
|
||||||
|
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
||||||
|
del num_vis # unused
|
||||||
|
mean, var, _ = mean_var_tot
|
||||||
|
# labels, pred: (B, L, K, n)
|
||||||
|
ss_res = tf.reduce_sum(tf.square(labels - pred) * assignment, axis=2)
|
||||||
|
ss_tot = var[tf.newaxis, tf.newaxis, :] # (1, 1, n)
|
||||||
|
return tf.reduce_mean(1.0 - ss_res / ss_tot)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,300 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Network modules."""
|
||||||
|
# pylint: disable=g-multiple-import, g-doc-args, g-short-docstring-punctuation
|
||||||
|
# pylint: disable=g-no-space-after-docstring-summary
|
||||||
|
from iodine.modules.distributions import FlatParameters
|
||||||
|
from iodine.modules.utils import flatten_all_but_last, get_act_func
|
||||||
|
import numpy as np
|
||||||
|
import shapeguard
|
||||||
|
import sonnet as snt
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
|
class CNN(snt.AbstractModule):
|
||||||
|
"""ConvNet2D followed by an MLP.
|
||||||
|
|
||||||
|
This is a typical encoder architecture for VAEs, and has been found to work
|
||||||
|
well. One small improvement is to append coordinate channels on the input,
|
||||||
|
though for most datasets the improvement obtained is negligible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cnn_opt, mlp_opt, mode="flatten", name="cnn"):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cnn_opt: Dictionary. Kwargs for the cnn. See vae_lib.ConvNet2D for
|
||||||
|
details.
|
||||||
|
mlp_opt: Dictionary. Kwargs for the mlp. See vae_lib.MLP for details.
|
||||||
|
name: String. Optional name.
|
||||||
|
"""
|
||||||
|
super().__init__(name=name)
|
||||||
|
if "activation" in cnn_opt:
|
||||||
|
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||||
|
self._cnn_opt = cnn_opt
|
||||||
|
|
||||||
|
if "activation" in mlp_opt:
|
||||||
|
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||||
|
self._mlp_opt = mlp_opt
|
||||||
|
|
||||||
|
self._mode = mode
|
||||||
|
|
||||||
|
def set_output_shapes(self, shape):
|
||||||
|
# assert self._mlp_opt['output_sizes'][-1] is None, self._mlp_opt
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
sg.guard(shape, "1, Y")
|
||||||
|
self._mlp_opt["output_sizes"][-1] = sg.Y
|
||||||
|
|
||||||
|
def _build(self, image):
|
||||||
|
"""Connect model to TensorFlow graph."""
|
||||||
|
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
|
||||||
|
sg.guard(flat_image, "B, H, W, C")
|
||||||
|
|
||||||
|
cnn = snt.nets.ConvNet2D(
|
||||||
|
activate_final=True,
|
||||||
|
paddings=("SAME",),
|
||||||
|
normalize_final=False,
|
||||||
|
**self._cnn_opt)
|
||||||
|
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||||
|
|
||||||
|
# run CNN
|
||||||
|
net = cnn(flat_image)
|
||||||
|
|
||||||
|
if self._mode == "flatten":
|
||||||
|
# flatten
|
||||||
|
net_shape = net.get_shape().as_list()
|
||||||
|
flat_shape = net_shape[:-3] + [np.prod(net_shape[-3:])]
|
||||||
|
net = tf.reshape(net, flat_shape)
|
||||||
|
elif self._mode == "avg_pool":
|
||||||
|
net = tf.reduce_mean(net, axis=[1, 2])
|
||||||
|
else:
|
||||||
|
raise KeyError('Unknown mode "{}"'.format(self._mode))
|
||||||
|
# run MLP
|
||||||
|
output = sg.guard(mlp(net), "B, Y")
|
||||||
|
return FlatParameters(unflatten(output))
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(snt.AbstractModule):
|
||||||
|
"""MLP."""
|
||||||
|
|
||||||
|
def __init__(self, name="mlp", **mlp_opt):
|
||||||
|
super().__init__(name=name)
|
||||||
|
if "activation" in mlp_opt:
|
||||||
|
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||||
|
self._mlp_opt = mlp_opt
|
||||||
|
assert mlp_opt["output_sizes"][-1] is None, mlp_opt
|
||||||
|
|
||||||
|
def set_output_shapes(self, shape):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
sg.guard(shape, "1, Y")
|
||||||
|
self._mlp_opt["output_sizes"][-1] = sg.Y
|
||||||
|
|
||||||
|
def _build(self, data):
|
||||||
|
"""Connect model to TensorFlow graph."""
|
||||||
|
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
flat_data, unflatten = flatten_all_but_last(data)
|
||||||
|
sg.guard(flat_data, "B, N")
|
||||||
|
|
||||||
|
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||||
|
# run MLP
|
||||||
|
output = sg.guard(mlp(flat_data), "B, Y")
|
||||||
|
return FlatParameters(unflatten(output))
|
||||||
|
|
||||||
|
|
||||||
|
class DeConv(snt.AbstractModule):
|
||||||
|
"""MLP followed by Deconv net.
|
||||||
|
|
||||||
|
This decoder is commonly used by vanilla VAE models. However, in practice
|
||||||
|
BroadcastConv (see below) seems to disentangle slightly better.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mlp_opt, cnn_opt, name="deconv"):
|
||||||
|
"""Constructor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mlp_opt: Dictionary. Kwargs for vae_lib.MLP.
|
||||||
|
cnn_opt: Dictionary. Kwargs for vae_lib.ConvNet2D for the CNN.
|
||||||
|
name: Optional name.
|
||||||
|
"""
|
||||||
|
super().__init__(name=name)
|
||||||
|
assert cnn_opt["output_channels"][-1] is None, cnn_opt
|
||||||
|
if "activation" in cnn_opt:
|
||||||
|
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||||
|
self._cnn_opt = cnn_opt
|
||||||
|
|
||||||
|
if mlp_opt and "activation" in mlp_opt:
|
||||||
|
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||||
|
self._mlp_opt = mlp_opt
|
||||||
|
self._target_out_shape = None
|
||||||
|
|
||||||
|
def set_output_shapes(self, shape):
|
||||||
|
self._target_out_shape = shape
|
||||||
|
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
|
||||||
|
|
||||||
|
def _build(self, z):
|
||||||
|
"""Connect model to TensorFlow graph."""
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
flat_z, unflatten = flatten_all_but_last(z)
|
||||||
|
sg.guard(flat_z, "B, Z")
|
||||||
|
sg.guard(self._target_out_shape, "H, W, C")
|
||||||
|
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||||
|
cnn = snt.nets.ConvNet2DTranspose(
|
||||||
|
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
|
||||||
|
net = mlp(flat_z)
|
||||||
|
output = sg.guard(cnn(net), "B, H, W, C")
|
||||||
|
return FlatParameters(unflatten(output))
|
||||||
|
|
||||||
|
|
||||||
|
class BroadcastConv(snt.AbstractModule):
|
||||||
|
"""MLP followed by a broadcast convolution.
|
||||||
|
|
||||||
|
This decoder takes a latent vector z, (optionally) applies an MLP to it,
|
||||||
|
then tiles the resulting vector across space to have dimension [B, H, W, C]
|
||||||
|
i.e. tiles across H and W. Then coordinate channels are appended and a
|
||||||
|
convolutional layer is applied.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cnn_opt,
|
||||||
|
mlp_opt=None,
|
||||||
|
coord_type="linear",
|
||||||
|
coord_freqs=3,
|
||||||
|
name="broadcast_conv",
|
||||||
|
):
|
||||||
|
"""Args:
|
||||||
|
cnn_opt: dict Kwargs for vae_lib.ConvNet2D for the CNN.
|
||||||
|
mlp_opt: None or dict If dictionary, then kwargs for snt.nets.MLP. If
|
||||||
|
None, then the model will not process the latent vector by an mlp.
|
||||||
|
coord_type: ["linear", "cos", None] type of coordinate channels to
|
||||||
|
add.
|
||||||
|
None: add no coordinate channels.
|
||||||
|
linear: two channels with values linearly spaced from -1. to 1. in
|
||||||
|
the H and W dimension respectively.
|
||||||
|
cos: coord_freqs^2 many channels containing cosine basis functions.
|
||||||
|
coord_freqs: int number of frequencies used to construct the cosine
|
||||||
|
basis functions (only for coord_type=="cos")
|
||||||
|
name: Optional name.
|
||||||
|
"""
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
assert cnn_opt["output_channels"][-1] is None, cnn_opt
|
||||||
|
if "activation" in cnn_opt:
|
||||||
|
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||||
|
self._cnn_opt = cnn_opt
|
||||||
|
|
||||||
|
if mlp_opt and "activation" in mlp_opt:
|
||||||
|
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||||
|
self._mlp_opt = mlp_opt
|
||||||
|
|
||||||
|
self._target_out_shape = None
|
||||||
|
self._coord_type = coord_type
|
||||||
|
self._coord_freqs = coord_freqs
|
||||||
|
|
||||||
|
def set_output_shapes(self, shape):
|
||||||
|
self._target_out_shape = shape
|
||||||
|
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
|
||||||
|
|
||||||
|
def _build(self, z):
|
||||||
|
"""Connect model to TensorFlow graph."""
|
||||||
|
assert self._target_out_shape is not None, "Call set_output_shape"
|
||||||
|
# reshape components into batch dimension before processing them
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
flat_z, unflatten = flatten_all_but_last(z)
|
||||||
|
sg.guard(flat_z, "B, Z")
|
||||||
|
sg.guard(self._target_out_shape, "H, W, C")
|
||||||
|
|
||||||
|
if self._mlp_opt is None:
|
||||||
|
mlp = tf.identity
|
||||||
|
else:
|
||||||
|
mlp = snt.nets.MLP(activate_final=True, **self._mlp_opt)
|
||||||
|
mlp_output = sg.guard(mlp(flat_z), "B, hidden")
|
||||||
|
|
||||||
|
# tile MLP output spatially and append coordinate channels
|
||||||
|
broadcast_mlp_output = tf.tile(
|
||||||
|
mlp_output[:, tf.newaxis, tf.newaxis],
|
||||||
|
multiples=tf.constant(sg["1, H, W, 1"]),
|
||||||
|
) # B, H, W, Z
|
||||||
|
|
||||||
|
dec_cnn_inputs = self.append_coordinate_channels(broadcast_mlp_output)
|
||||||
|
|
||||||
|
cnn = snt.nets.ConvNet2D(
|
||||||
|
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
|
||||||
|
cnn_outputs = cnn(dec_cnn_inputs)
|
||||||
|
sg.guard(cnn_outputs, "B, H, W, C")
|
||||||
|
|
||||||
|
return FlatParameters(unflatten(cnn_outputs))
|
||||||
|
|
||||||
|
def append_coordinate_channels(self, output):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
sg.guard(output, "B, H, W, C")
|
||||||
|
if self._coord_type is None:
|
||||||
|
return output
|
||||||
|
if self._coord_type == "linear":
|
||||||
|
w_coords = tf.linspace(-1.0, 1.0, sg.W)[None, None, :, None]
|
||||||
|
h_coords = tf.linspace(-1.0, 1.0, sg.H)[None, :, None, None]
|
||||||
|
w_coords = tf.tile(w_coords, sg["B, H, 1, 1"])
|
||||||
|
h_coords = tf.tile(h_coords, sg["B, 1, W, 1"])
|
||||||
|
return tf.concat([output, h_coords, w_coords], axis=-1)
|
||||||
|
elif self._coord_type == "cos":
|
||||||
|
freqs = sg.guard(tf.range(0.0, self._coord_freqs), "F")
|
||||||
|
valx = tf.linspace(0.0, np.pi, sg.W)[None, None, :, None, None]
|
||||||
|
valy = tf.linspace(0.0, np.pi, sg.H)[None, :, None, None, None]
|
||||||
|
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
||||||
|
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
||||||
|
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
||||||
|
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
|
||||||
|
return tf.concat([output, coords], axis=-1)
|
||||||
|
else:
|
||||||
|
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
||||||
|
|
||||||
|
|
||||||
|
class LSTM(snt.RNNCore):
|
||||||
|
"""Wrapper around snt.LSTM that supports multi-layers and runs K components in
|
||||||
|
parallel.
|
||||||
|
|
||||||
|
Expects input data of shape (B, K, H) and outputs data of shape (B, K, Y)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_sizes, name="lstm"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._hidden_sizes = hidden_sizes
|
||||||
|
with self._enter_variable_scope():
|
||||||
|
self._lstm_layers = [snt.LSTM(hidden_size=h) for h in self._hidden_sizes]
|
||||||
|
|
||||||
|
def initial_state(self, batch_size, **kwargs):
|
||||||
|
return [
|
||||||
|
lstm.initial_state(batch_size, **kwargs) for lstm in self._lstm_layers
|
||||||
|
]
|
||||||
|
|
||||||
|
def _build(self, data, prev_states):
|
||||||
|
assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
|
||||||
|
assert len(prev_states) == len(self._hidden_sizes)
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
sg.guard(data, "B, K, H")
|
||||||
|
data = sg.reshape(data, "B*K, H")
|
||||||
|
|
||||||
|
out = data
|
||||||
|
new_states = []
|
||||||
|
for lstm, pstate in zip(self._lstm_layers, prev_states):
|
||||||
|
out, nstate = lstm(out, pstate)
|
||||||
|
new_states.append(nstate)
|
||||||
|
|
||||||
|
sg.guard(out, "B*K, Y")
|
||||||
|
out = sg.reshape(out, "B, K, Y")
|
||||||
|
return out, new_states
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Plotting tools for IODINE."""
|
||||||
|
# pylint: disable=unused-import, missing-docstring, unused-variable
|
||||||
|
# pylint: disable=invalid-name, unexpected-keyword-arg
|
||||||
|
import functools
|
||||||
|
from iodine.modules.utils import get_mask_plot_colors
|
||||||
|
from matplotlib.colors import hsv_to_rgb
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = ("get_mask_plot_colors", "example_plot", "iterations_plot",
|
||||||
|
"inputs_plot")
|
||||||
|
|
||||||
|
|
||||||
|
def clean_ax(ax, color=None, lw=4.0):
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
if color is not None:
|
||||||
|
for spine in ax.spines.values():
|
||||||
|
spine.set_linewidth(lw)
|
||||||
|
spine.set_color(color)
|
||||||
|
|
||||||
|
|
||||||
|
def optional_ax(fn):
|
||||||
|
|
||||||
|
def _wrapped(*args, **kwargs):
|
||||||
|
if kwargs.get("ax", None) is None:
|
||||||
|
figsize = kwargs.pop("figsize", (4, 4))
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
kwargs["ax"] = ax
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def optional_clean_ax(fn):
|
||||||
|
|
||||||
|
def _wrapped(*args, **kwargs):
|
||||||
|
if kwargs.get("ax", None) is None:
|
||||||
|
figsize = kwargs.pop("figsize", (4, 4))
|
||||||
|
fig, ax = plt.subplots(figsize=figsize)
|
||||||
|
kwargs["ax"] = ax
|
||||||
|
color = kwargs.pop("color", None)
|
||||||
|
lw = kwargs.pop("lw", 4.0)
|
||||||
|
res = fn(*args, **kwargs)
|
||||||
|
clean_ax(kwargs["ax"], color, lw)
|
||||||
|
return res
|
||||||
|
|
||||||
|
return _wrapped
|
||||||
|
|
||||||
|
|
||||||
|
@optional_clean_ax
|
||||||
|
def show_img(img, mask=None, ax=None, norm=False):
|
||||||
|
if norm:
|
||||||
|
vmin, vmax = np.min(img), np.max(img)
|
||||||
|
img = (img - vmin) / (vmax - vmin)
|
||||||
|
if mask is not None:
|
||||||
|
img = img * mask + np.ones_like(img) * (1.0 - mask)
|
||||||
|
|
||||||
|
return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
|
||||||
|
|
||||||
|
|
||||||
|
@optional_clean_ax
|
||||||
|
def show_mask(m, ax):
|
||||||
|
color_conv = get_mask_plot_colors(m.shape[0])
|
||||||
|
color_mask = np.dot(np.transpose(m, [1, 2, 0]), color_conv)
|
||||||
|
return ax.imshow(color_mask.clip(0.0, 1.0), interpolation="nearest")
|
||||||
|
|
||||||
|
|
||||||
|
@optional_clean_ax
|
||||||
|
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
||||||
|
return ax.matshow(
|
||||||
|
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||||
|
|
||||||
|
|
||||||
|
@optional_clean_ax
|
||||||
|
def show_coords(m, ax):
|
||||||
|
vmin, vmax = np.min(m), np.max(m)
|
||||||
|
m = (m - vmin) / (vmax - vmin)
|
||||||
|
color_conv = get_mask_plot_colors(m.shape[-1])
|
||||||
|
color_mask = np.dot(m, color_conv)
|
||||||
|
return ax.imshow(color_mask, interpolation="nearest")
|
||||||
|
|
||||||
|
|
||||||
|
def example_plot(rinfo,
|
||||||
|
b=0,
|
||||||
|
t=-1,
|
||||||
|
mask_components=False,
|
||||||
|
size=2,
|
||||||
|
column_titles=True):
|
||||||
|
image = rinfo["data"]["image"][b, 0]
|
||||||
|
recons = rinfo["outputs"]["recons"][b, t, 0]
|
||||||
|
pred_mask = rinfo["outputs"]["pred_mask"][b, t]
|
||||||
|
components = rinfo["outputs"]["components"][b, t]
|
||||||
|
|
||||||
|
K, H, W, C = components.shape
|
||||||
|
colors = get_mask_plot_colors(K)
|
||||||
|
|
||||||
|
nrows = 1
|
||||||
|
ncols = 3 + K
|
||||||
|
fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||||
|
|
||||||
|
show_img(image, ax=axes[0], color="#000000")
|
||||||
|
show_img(recons, ax=axes[1], color="#000000")
|
||||||
|
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
|
||||||
|
for k in range(K):
|
||||||
|
mask = pred_mask[k] if mask_components else None
|
||||||
|
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
||||||
|
|
||||||
|
if column_titles:
|
||||||
|
labels = ["Image", "Recons.", "Mask"
|
||||||
|
] + ["Component {}".format(k + 1) for k in range(K)]
|
||||||
|
for ax, title in zip(axes, labels):
|
||||||
|
ax.set_title(title)
|
||||||
|
plt.subplots_adjust(hspace=0.03, wspace=0.035)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||||
|
image = rinfo["data"]["image"][b]
|
||||||
|
true_mask = rinfo["data"]["true_mask"][b]
|
||||||
|
recons = rinfo["outputs"]["recons"][b]
|
||||||
|
pred_mask = rinfo["outputs"]["pred_mask"][b]
|
||||||
|
pred_mask_logits = rinfo["outputs"]["pred_mask_logits"][b]
|
||||||
|
components = rinfo["outputs"]["components"][b]
|
||||||
|
|
||||||
|
T, K, H, W, C = components.shape
|
||||||
|
colors = get_mask_plot_colors(K)
|
||||||
|
nrows = T + 1
|
||||||
|
ncols = 2 + K
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||||
|
for t in range(T):
|
||||||
|
show_img(recons[t, 0], ax=axes[t, 0])
|
||||||
|
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
|
||||||
|
axes[t, 0].set_ylabel("iter {}".format(t))
|
||||||
|
for k in range(K):
|
||||||
|
mask = pred_mask[t, k] if mask_components else None
|
||||||
|
show_img(components[t, k], ax=axes[t, k + 2], color=colors[k], mask=mask)
|
||||||
|
|
||||||
|
axes[0, 0].set_title("Reconstruction")
|
||||||
|
axes[0, 1].set_title("Mask")
|
||||||
|
show_img(image[0], ax=axes[T, 0])
|
||||||
|
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
|
||||||
|
vmin = np.min(pred_mask_logits[T - 1])
|
||||||
|
vmax = np.max(pred_mask_logits[T - 1])
|
||||||
|
|
||||||
|
for k in range(K):
|
||||||
|
axes[0, k + 2].set_title("Component {}".format(k + 1)) # , color=colors[k])
|
||||||
|
show_mat(
|
||||||
|
pred_mask_logits[T - 1, k], ax=axes[T, k + 2], vmin=vmin, vmax=vmax)
|
||||||
|
axes[T, k + 2].set_xlabel(
|
||||||
|
"Mask Logits for\nComponent {}".format(k + 1)) # , color=colors[k])
|
||||||
|
axes[T, 0].set_xlabel("Input Image")
|
||||||
|
axes[T, 1].set_xlabel("Ground Truth Mask")
|
||||||
|
|
||||||
|
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||||
|
return fig
|
||||||
|
|
||||||
|
|
||||||
|
def inputs_plot(rinfo, b=0, t=0, size=2):
|
||||||
|
B, T, K, H, W, C = rinfo["outputs"]["components"].shape
|
||||||
|
colors = get_mask_plot_colors(K)
|
||||||
|
inputs = rinfo["inputs"]["spatial"]
|
||||||
|
rows = [
|
||||||
|
("image", show_img, False),
|
||||||
|
("components", show_img, False),
|
||||||
|
("dcomponents", functools.partial(show_img, norm=True), False),
|
||||||
|
("mask", show_mat, True),
|
||||||
|
("pred_mask", show_mat, True),
|
||||||
|
("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
|
||||||
|
("posterior", show_mat, True),
|
||||||
|
("log_prob", show_mat, True),
|
||||||
|
("counterfactual", show_mat, True),
|
||||||
|
("coordinates", show_coords, False),
|
||||||
|
]
|
||||||
|
rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
|
||||||
|
nrows = len(rows)
|
||||||
|
ncols = K + 1
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(
|
||||||
|
nrows=nrows,
|
||||||
|
ncols=ncols,
|
||||||
|
figsize=(ncols * size - size * 0.9, nrows * size),
|
||||||
|
gridspec_kw={"width_ratios": [1] * K + [0.1]},
|
||||||
|
)
|
||||||
|
for r, (name, plot_fn, make_cbar) in enumerate(rows):
|
||||||
|
axes[r, 0].set_ylabel(name)
|
||||||
|
if make_cbar:
|
||||||
|
vmin = np.min(inputs[name][b, t])
|
||||||
|
vmax = np.max(inputs[name][b, t])
|
||||||
|
if np.abs(vmin - vmax) < 1e-6:
|
||||||
|
vmin -= 0.1
|
||||||
|
vmax += 0.1
|
||||||
|
plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
|
||||||
|
# print("range of {:<16}: [{:0.2f}, {:0.2f}]".format(name, vmin, vmax))
|
||||||
|
for k in range(K):
|
||||||
|
if inputs[name].shape[2] == 1:
|
||||||
|
m = inputs[name][b, t, 0]
|
||||||
|
color = (0.0, 0.0, 0.0)
|
||||||
|
else:
|
||||||
|
m = inputs[name][b, t, k]
|
||||||
|
color = colors[k]
|
||||||
|
mappable = plot_fn(m, ax=axes[r, k], color=color)
|
||||||
|
if make_cbar:
|
||||||
|
fig.colorbar(mappable, cax=axes[r, K])
|
||||||
|
else:
|
||||||
|
axes[r, K].set_visible(False)
|
||||||
|
for k in range(K):
|
||||||
|
axes[0, k].set_title("Component {}".format(k + 1)) # , color=colors[k])
|
||||||
|
|
||||||
|
plt.subplots_adjust(hspace=0.05, wspace=0.05)
|
||||||
|
return fig
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
# Lint as: python3
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""Iterative refinement modules."""
|
||||||
|
# pylint: disable=g-doc-bad-indent, unused-variable
|
||||||
|
from iodine.modules import utils
|
||||||
|
import shapeguard
|
||||||
|
import sonnet as snt
|
||||||
|
import tensorflow.compat.v1 as tf
|
||||||
|
|
||||||
|
|
||||||
|
class RefinementCore(snt.RNNCore):
|
||||||
|
"""Recurrent Refinement Module.
|
||||||
|
|
||||||
|
Refinement modules take as inputs:
|
||||||
|
* previous state (which could be an arbitrary nested structure)
|
||||||
|
* current inputs which include
|
||||||
|
* image-space inputs like pixel-based errors, or mask-posteriors
|
||||||
|
* latent-space inputs like the previous z_dist, or dz
|
||||||
|
|
||||||
|
They use these inputs to produce:
|
||||||
|
* output (usually a new z_dist)
|
||||||
|
* new_state
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
encoder_net,
|
||||||
|
recurrent_net,
|
||||||
|
refinement_head,
|
||||||
|
name="refinement"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._encoder_net = encoder_net
|
||||||
|
self._recurrent_net = recurrent_net
|
||||||
|
self._refinement_head = refinement_head
|
||||||
|
self._sg = shapeguard.ShapeGuard()
|
||||||
|
|
||||||
|
def initial_state(self, batch_size, **unused_kwargs):
|
||||||
|
return self._recurrent_net.initial_state(batch_size)
|
||||||
|
|
||||||
|
def _build(self, inputs, prev_state):
|
||||||
|
sg = self._sg
|
||||||
|
assert "spatial" in inputs, inputs.keys()
|
||||||
|
assert "flat" in inputs, inputs.keys()
|
||||||
|
assert "zp" in inputs["flat"], inputs["flat"].keys()
|
||||||
|
zp = sg.guard(inputs["flat"]["zp"], "B, K, Zp")
|
||||||
|
|
||||||
|
x = sg.guard(self.prepare_spatial_inputs(inputs["spatial"]), "B*K, H, W, C")
|
||||||
|
h1 = sg.guard(self._encoder_net(x).params, "B*K, H1")
|
||||||
|
h2 = sg.guard(self.prepare_flat_inputs(h1, inputs["flat"]), "B*K, H2")
|
||||||
|
h2_unflattened = sg.reshape(h2, "B, K, H2")
|
||||||
|
h3, next_state = self._recurrent_net(h2_unflattened, prev_state)
|
||||||
|
sg.guard(h3, "B, K, H3")
|
||||||
|
outputs = sg.guard(self._refinement_head(zp, h3), "B, K, Y")
|
||||||
|
|
||||||
|
del self._sg.B
|
||||||
|
return outputs, next_state
|
||||||
|
|
||||||
|
def prepare_spatial_inputs(self, inputs):
|
||||||
|
values = []
|
||||||
|
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
|
||||||
|
if val.shape.as_list()[1] == 1:
|
||||||
|
self._sg.guard(val, "B, 1, H, W, _C")
|
||||||
|
val = tf.tile(val, self._sg["1, K, 1, 1, 1"])
|
||||||
|
else:
|
||||||
|
self._sg.guard(val, "B, K, H, W, _C")
|
||||||
|
values.append(val)
|
||||||
|
concat_inputs = self._sg.guard(tf.concat(values, axis=-1), "B, K, H, W, C")
|
||||||
|
return self._sg.reshape(concat_inputs, "B*K, H, W, C")
|
||||||
|
|
||||||
|
def prepare_flat_inputs(self, hidden, inputs):
|
||||||
|
values = [self._sg.guard(hidden, "B*K, H1")]
|
||||||
|
|
||||||
|
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
|
||||||
|
self._sg.guard(val, "B, K, _")
|
||||||
|
val_flat = tf.reshape(val, self._sg["B*K"] + [-1])
|
||||||
|
values.append(val_flat)
|
||||||
|
return tf.concat(values, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class ResHead(snt.AbstractModule):
|
||||||
|
"""Updates Zp using a residual mechanism."""
|
||||||
|
|
||||||
|
def __init__(self, name="residual_head"):
|
||||||
|
super().__init__(name=name)
|
||||||
|
|
||||||
|
def _build(self, zp_old, inputs):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
sg.guard(zp_old, "B, K, Zp")
|
||||||
|
sg.guard(inputs, "B, K, H")
|
||||||
|
update = snt.Linear(sg.Zp)
|
||||||
|
|
||||||
|
flat_zp = sg.reshape(zp_old, "B*K, Zp")
|
||||||
|
flat_inputs = sg.reshape(inputs, "B*K, H")
|
||||||
|
|
||||||
|
zp = flat_zp + update(flat_inputs)
|
||||||
|
|
||||||
|
return sg.reshape(zp, "B, K, Zp")
|
||||||
|
|
||||||
|
|
||||||
|
class PredictorCorrectorHead(snt.AbstractModule):
|
||||||
|
"""This refinement head is used for sequential data.
|
||||||
|
|
||||||
|
At every step it computes a prediction from the λ of the previous timestep
|
||||||
|
and an update from the refinement network of the current timestep.
|
||||||
|
|
||||||
|
The next step λ' is computed as a gated combination of both:
|
||||||
|
λ' = g * λ_corr + (1-g) * λ_pred
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_sizes=(64,),
|
||||||
|
pred_gate_bias=0.0,
|
||||||
|
corrector_gate_bias=0.0,
|
||||||
|
activation=tf.nn.elu,
|
||||||
|
name="predcorr_head",
|
||||||
|
):
|
||||||
|
super().__init__(name=name)
|
||||||
|
self._hidden_sizes = hidden_sizes
|
||||||
|
self._activation = utils.get_act_func(activation)
|
||||||
|
self._pred_gate_bias = pred_gate_bias
|
||||||
|
self._corrector_gate_bias = corrector_gate_bias
|
||||||
|
|
||||||
|
def _build(self, zp_old, inputs):
|
||||||
|
sg = shapeguard.ShapeGuard()
|
||||||
|
sg.guard(zp_old, "B, K, Zp")
|
||||||
|
sg.guard(inputs, "B, K, H")
|
||||||
|
update = snt.Linear(sg.Zp)
|
||||||
|
update_gate = snt.Linear(sg.Zp)
|
||||||
|
predict = snt.nets.MLP(
|
||||||
|
output_sizes=list(self._hidden_sizes) + [sg.Zp * 2],
|
||||||
|
activation=self._activation,
|
||||||
|
)
|
||||||
|
|
||||||
|
flat_zp = sg.reshape(zp_old, "B*K, Zp")
|
||||||
|
flat_inputs = sg.reshape(inputs, "B*K, H")
|
||||||
|
|
||||||
|
g = tf.nn.sigmoid(update_gate(flat_inputs) + self._corrector_gate_bias)
|
||||||
|
u = update(flat_inputs)
|
||||||
|
|
||||||
|
# a slightly more efficient way of computing the gated update
|
||||||
|
# (1-g) * flat_zp + g * u
|
||||||
|
zp_corrected = flat_zp + g * (u - flat_zp)
|
||||||
|
|
||||||
|
predicted = predict(flat_zp)
|
||||||
|
pred_up = predicted[:, :sg.Zp]
|
||||||
|
pred_gate = tf.nn.sigmoid(predicted[:, sg.Zp:] + self._pred_gate_bias)
|
||||||
|
|
||||||
|
zp = zp_corrected + pred_gate * (pred_up - zp_corrected)
|
||||||
|
|
||||||
|
return sg.reshape(zp, "B, K, Zp")
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,9 @@
|
|||||||
|
tensorflow-gpu==1.14.0
|
||||||
|
tensorflow-probability==0.7.0
|
||||||
|
dm-sonnet==1.35
|
||||||
|
sacred>=0.7,<0.8
|
||||||
|
shapeguard
|
||||||
|
seaborn
|
||||||
|
pymongo
|
||||||
|
jupyterlab
|
||||||
|
git+git://github.com/deepmind/multi_object_datasets.git
|
||||||
Executable
+36
@@ -0,0 +1,36 @@
|
|||||||
|
#!/bin/sh
|
||||||
|
# Copyright 2019 Deepmind Technologies Limited.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
set -e
|
||||||
|
|
||||||
|
echo "downloading checkpoints from GCP"
|
||||||
|
iodine/download_checkpoints.sh
|
||||||
|
|
||||||
|
python3 -m venv iodine_venv
|
||||||
|
source iodine_venv/bin/activate
|
||||||
|
pip3 install --upgrade setuptools wheel
|
||||||
|
pip3 install -r iodine/requirements.txt
|
||||||
|
|
||||||
|
# Get some fake data and put it where the real multi_objects_dataset files live.
|
||||||
|
mkdir -p iodine/multi_object_datasets
|
||||||
|
cp iodine/test_data/tetrominoes_mini.tfrecords iodine/multi_object_datasets/tetrominoes_train.tfrecords
|
||||||
|
|
||||||
|
# Run training with a cut down size.
|
||||||
|
python3 -m iodine.main \
|
||||||
|
-f with tetrominoes \
|
||||||
|
data.shuffle_buffer=2 \
|
||||||
|
data.batch_size=2 \
|
||||||
|
n_z=4 \
|
||||||
|
num_components=3 \
|
||||||
|
stop_after_steps=11
|
||||||
Binary file not shown.
Reference in New Issue
Block a user