mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-26 17:35:21 +08:00
Release of IODINE
PiperOrigin-RevId: 299101887
This commit is contained in:
@@ -10,6 +10,7 @@ env:
|
||||
matrix:
|
||||
- PROJECT="tvt"
|
||||
- PROJECT="cs_gan"
|
||||
- PROJECT="iodine"
|
||||
- PROJECT="transporter"
|
||||
before_script:
|
||||
- sudo apt-get update -qq
|
||||
|
||||
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
|
||||
|
||||
## Projects
|
||||
|
||||
* [Multi-Object Representation Learning with Iterative Variational Inference (IODINE)](iodine)
|
||||
* [AlphaFold CASP13](alphafold_casp13), Nature 2020
|
||||
* [Unrestricted Adversarial Challenge](unrestricted_advx)
|
||||
* [Hierarchical Probabilistic U-Net (HPU-Net)](hierarchical_probabilistic_unet)
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -0,0 +1,142 @@
|
||||
# IODINE
|
||||
Reference implementation for the paper ["Multi-Object Representation Learning with Iterative Variational Inference"](https://arxiv.org/abs/1903.00450).
|
||||
This repository contains:
|
||||
|
||||
* An IODINE implementation in Tensorflow v1.
|
||||
* Configurations used in the paper (checkpoints available in Cloud Storage) for:
|
||||
* CLEVR
|
||||
* Multi-dSprites
|
||||
* Tetrominoes
|
||||
* A notebook for running and inspecting the model and plotting the results
|
||||
|
||||
|
||||
## Installation
|
||||
1. Clone the DeepMind research repository:
|
||||
|
||||
``` bash
|
||||
git clone https://github.com/deepmind/deepmind-research.git
|
||||
cd deepmind-research
|
||||
```
|
||||
|
||||
2. Download the checkpoints from GCP. A shell script is provided:
|
||||
|
||||
```bash
|
||||
./iodine/download_checkpoints.sh
|
||||
```
|
||||
|
||||
On platforms without wget, the files can be downloaded from [this webpage](https://console.cloud.google.com/storage/browser/deepmind-research-iodine?pli=1)
|
||||
and the unzipped `checkpoints/` folder should be placed in
|
||||
`deepmind-research/iodine/checkpoints`.
|
||||
|
||||
|
||||
3. Prepare a Python 3 environment - virtualenv is recommended.
|
||||
|
||||
```bash
|
||||
python3 -m venv iodine_venv
|
||||
source iodine_venv/bin/activate
|
||||
```
|
||||
|
||||
4. Install dependencies:
|
||||
|
||||
```bash
|
||||
pip3 install -r iodine/requirements.txt
|
||||
```
|
||||
|
||||
5. The `multi_object_datasets` package installed via requirements.txt provides python code to open the data files, but not the data files themselves.
|
||||
Download the desired datasets either manually from the [Google Cloud Storage](https://console.cloud.google.com/storage/browser/multi-object-datasets) or using the commands below:
|
||||
|
||||
```bash
|
||||
pushd iodine/multi_object_datasets
|
||||
# CLEVR
|
||||
wget https://storage.googleapis.com/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords
|
||||
# Multi-dSprites
|
||||
wget https://storage.googleapis.com/multi-object-datasets/multi_dsprites/multi_dsprites_colored_on_grayscale.tfrecords
|
||||
# Tetrominoes
|
||||
wget https://storage.googleapis.com/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords
|
||||
# Get back to location containing 'iodine' directory
|
||||
popd
|
||||
```
|
||||
|
||||
See [multi_object_datasets repository](https://github.com/deepmind/multi_object_datasets)
|
||||
for further details.
|
||||
6. Make sure that you have CUDA 10 and CuDNN 7 installed
|
||||
|
||||
|
||||
## Interact with a Model
|
||||
Use the jupyter notebook `Eval.ipynb` to load and run one of the checkpoints.
|
||||
It also contains code to plot the outputs and latent traversals.
|
||||
|
||||
|
||||
## Train a Model
|
||||
To train your own model use the [Sacred](https://github.com/IDSIA/sacred) experiment defined in `main.py`.
|
||||
The configurations used in the paper for the different datasets are available as [named configs](https://sacred.readthedocs.io/en/latest/configuration.html#named-configurations) inside of `configuration.py`.
|
||||
### Train a new model
|
||||
* CLEVR6
|
||||
|
||||
```bash
|
||||
python3 -m iodine.main -f with clevr6
|
||||
```
|
||||
|
||||
* Multi-dSprites
|
||||
|
||||
```bash
|
||||
python3 -m iodine.main -f with multi_dsprites
|
||||
```
|
||||
|
||||
* Tetrominoes
|
||||
|
||||
```bash
|
||||
python3 -m iodine.main -f with tetrominoes
|
||||
```
|
||||
|
||||
It is recommended to add an observer to your run to let Sacred record the details of run.
|
||||
To add a [FileStorageObserver](https://sacred.readthedocs.io/en/latest/command_line.html#filestorage-observer) add `-F my_storage_dir`, and add `-m my_db_name` for a [MongoObserver](https://sacred.readthedocs.io/en/latest/command_line.html#mongodb-observer).
|
||||
|
||||
### Adjusting Config Values
|
||||
The experiment has a configuration that can be printed and adjusted from the commandline. E.g.:
|
||||
|
||||
``` bash
|
||||
# print configuration
|
||||
python3 -m iodine.main -f print_config with clevr6
|
||||
# run experiment after adjusting batch_size and the size of the shuffle buffer
|
||||
python3 -m iodine.main -f with clevr6 batch_size=2 data.shuffle_buffer=100
|
||||
```
|
||||
|
||||
### Tensorboard
|
||||
Each run stores checkpoints and summaries in the directory specified by `checkpoint_dir`, to which a suffix based on the run_id is appended.
|
||||
If an observer is added the `run_id` is set automatically. Otherwise it should be set manually using e.g. `run_id=5`.
|
||||
|
||||
Summaries can be viewed using tensorboard. E.g. like this for clevr6 (assuming `run_id=1`):
|
||||
|
||||
```bash
|
||||
tensorboard --log-dir iodine/checkpoints/clevr6_1
|
||||
```
|
||||
|
||||
### Continue Previous Run
|
||||
To continue a previous run pass `continue_run=True` and the path of the checkpoints:
|
||||
|
||||
```bash
|
||||
python3 -m iodine.main -f with clevr6 checkpoint_dir=iodine/checkpoints/clevr6_1
|
||||
```
|
||||
|
||||
## Code Structure
|
||||
The main experiment defined in `main.py` uses `sacred` and the configurations for the different datasets are added as named configs and can be found in `configuration.py`.
|
||||
The model implementation can be found in the `modules` directory and is based on `tensorflow` and `sonnet`:
|
||||
|
||||
* `iodine.py` The main IODINE module that assembles the decoder, refinement network, distributions and factor regressor.
|
||||
* `decoder.py` The ComponentDecoder which is a wrapper around networks that takes care of splitting the output channels into means and masks.
|
||||
* `refinement.py` The refinement components assembles the encoder network, LSTM and refinement head.
|
||||
* `networks.py` Different standard networks such as CNN, BroadcastCNN, and LSTM.
|
||||
* `distribution.py` Definition of the latent and pixel distributions.
|
||||
* `factor_eval.py` Contains the factor regressor which predicts the true factors from the inferred object latents.
|
||||
* `data.py` Dataset wrappers around `multi_object_datasets` that take care of shuffling, batching and preprocessing.
|
||||
* `plotting.py` Helper functions for plotting results.
|
||||
* `utils.py` General helper functions.
|
||||
|
||||
|
||||
---
|
||||
**DISCLAIMER**
|
||||
|
||||
This is not an officially supported Google product.
|
||||
|
||||
---
|
||||
@@ -0,0 +1,370 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configurations for IODINE."""
|
||||
# pylint: disable=missing-docstring, unused-variable
|
||||
import math
|
||||
|
||||
|
||||
def clevr6():
|
||||
n_z = 64 # number of latent dimensions
|
||||
num_components = 7 # number of components (K)
|
||||
num_iters = 5
|
||||
checkpoint_dir = "iodine/checkpoints/clevr6"
|
||||
|
||||
# For the paper we used 8 GPUs with a batch size of 4 each.
|
||||
# This means a total batch size of 32, which is too large for a single GPU.
|
||||
# When reducing the batch size, the learning rate should also be lowered.
|
||||
batch_size = 4
|
||||
learn_rate = 0.001 * math.sqrt(batch_size / 32)
|
||||
|
||||
data = {
|
||||
"constructor": "iodine.modules.data.CLEVR",
|
||||
"batch_size": batch_size,
|
||||
"path": "multi_object_datasets/clevr_with_masks_train.tfrecords",
|
||||
"max_num_objects": 6,
|
||||
}
|
||||
|
||||
model = {
|
||||
"constructor": "iodine.modules.iodine.IODINE",
|
||||
"n_z": n_z,
|
||||
"num_components": num_components,
|
||||
"num_iters": num_iters,
|
||||
"iter_loss_weight": "linspace",
|
||||
"coord_type": "linear",
|
||||
"decoder": {
|
||||
"constructor": "iodine.modules.decoder.ComponentDecoder",
|
||||
"pixel_decoder": {
|
||||
"constructor": "iodine.modules.networks.BroadcastConv",
|
||||
"cnn_opt": {
|
||||
# Final channels is irrelevant with target_output_shape
|
||||
"output_channels": [64, 64, 64, 64, None],
|
||||
"kernel_shapes": [3],
|
||||
"strides": [1],
|
||||
"activation": "elu",
|
||||
},
|
||||
"coord_type": "linear",
|
||||
},
|
||||
},
|
||||
"refinement_core": {
|
||||
"constructor": "iodine.modules.refinement.RefinementCore",
|
||||
"encoder_net": {
|
||||
"constructor": "iodine.modules.networks.CNN",
|
||||
"mode": "avg_pool",
|
||||
"cnn_opt": {
|
||||
"output_channels": [64, 64, 64, 64],
|
||||
"strides": [2],
|
||||
"kernel_shapes": [3],
|
||||
"activation": "elu",
|
||||
},
|
||||
"mlp_opt": {
|
||||
"output_sizes": [256, 256],
|
||||
"activation": "elu"
|
||||
},
|
||||
},
|
||||
"recurrent_net": {
|
||||
"constructor": "iodine.modules.networks.LSTM",
|
||||
"hidden_sizes": [256],
|
||||
},
|
||||
"refinement_head": {
|
||||
"constructor": "iodine.modules.refinement.ResHead"
|
||||
},
|
||||
},
|
||||
"latent_dist": {
|
||||
"constructor": "iodine.modules.distributions.LocScaleDistribution",
|
||||
"dist": "normal",
|
||||
"scale_act": "softplus",
|
||||
"scale": "var",
|
||||
"name": "latent_dist",
|
||||
},
|
||||
"output_dist": {
|
||||
"constructor": "iodine.modules.distributions.MaskedMixture",
|
||||
"num_components": num_components,
|
||||
"component_dist": {
|
||||
"constructor":
|
||||
"iodine.modules.distributions.LocScaleDistribution",
|
||||
"dist":
|
||||
"logistic",
|
||||
"scale":
|
||||
"fixed",
|
||||
"scale_val":
|
||||
0.03,
|
||||
"name":
|
||||
"pixel_distribution",
|
||||
},
|
||||
},
|
||||
"factor_evaluator": {
|
||||
"constructor":
|
||||
"iodine.modules.factor_eval.FactorRegressor",
|
||||
"mapping": [
|
||||
("color", 9, "categorical"),
|
||||
("shape", 4, "categorical"),
|
||||
("size", 3, "categorical"),
|
||||
("position", 3, "scalar"),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
optimizer = {
|
||||
"constructor": "tensorflow.train.AdamOptimizer",
|
||||
"learning_rate": {
|
||||
"constructor": "tensorflow.train.exponential_decay",
|
||||
"learning_rate": learn_rate,
|
||||
"global_step": {
|
||||
"constructor": "tensorflow.train.get_or_create_global_step"
|
||||
},
|
||||
"decay_steps": 1000000,
|
||||
"decay_rate": 0.1,
|
||||
},
|
||||
"beta1": 0.95,
|
||||
}
|
||||
|
||||
|
||||
def multi_dsprites():
|
||||
n_z = 16 # number of latent dimensions
|
||||
num_components = 6 # number of components (K)
|
||||
num_iters = 5
|
||||
checkpoint_dir = "iodine/checkpoints/multi_dsprites"
|
||||
|
||||
# For the paper we used 8 GPUs with a batch size of 16 each.
|
||||
# This means a total batch size of 128, which is too large for a single GPU.
|
||||
# When reducing the batch size, the learning rate should also be lowered.
|
||||
batch_size = 16
|
||||
learn_rate = 0.0003 * math.sqrt(batch_size / 128)
|
||||
|
||||
data = {
|
||||
"constructor":
|
||||
"iodine.modules.data.MultiDSprites",
|
||||
"batch_size":
|
||||
batch_size,
|
||||
"path":
|
||||
"multi_object_datasets/multi_dsprites_colored_on_grayscale.tfrecords",
|
||||
"dataset_variant":
|
||||
"colored_on_grayscale",
|
||||
"min_num_objs":
|
||||
3,
|
||||
"max_num_objs":
|
||||
3,
|
||||
}
|
||||
|
||||
model = {
|
||||
"constructor": "iodine.modules.iodine.IODINE",
|
||||
"n_z": n_z,
|
||||
"num_components": num_components,
|
||||
"num_iters": num_iters,
|
||||
"iter_loss_weight": "linspace",
|
||||
"coord_type": "cos",
|
||||
"coord_freqs": 3,
|
||||
"decoder": {
|
||||
"constructor": "iodine.modules.decoder.ComponentDecoder",
|
||||
"pixel_decoder": {
|
||||
"constructor": "iodine.modules.networks.BroadcastConv",
|
||||
"cnn_opt": {
|
||||
# Final channels is irrelevant with target_output_shape
|
||||
"output_channels": [32, 32, 32, 32, None],
|
||||
"kernel_shapes": [5],
|
||||
"strides": [1],
|
||||
"activation": "elu",
|
||||
},
|
||||
"coord_type": "linear",
|
||||
},
|
||||
},
|
||||
"refinement_core": {
|
||||
"constructor": "iodine.modules.refinement.RefinementCore",
|
||||
"encoder_net": {
|
||||
"constructor": "iodine.modules.networks.CNN",
|
||||
"mode": "avg_pool",
|
||||
"cnn_opt": {
|
||||
"output_channels": [32, 32, 32],
|
||||
"strides": [2],
|
||||
"kernel_shapes": [5],
|
||||
"activation": "elu",
|
||||
},
|
||||
"mlp_opt": {
|
||||
"output_sizes": [128],
|
||||
"activation": "elu"
|
||||
},
|
||||
},
|
||||
"recurrent_net": {
|
||||
"constructor": "iodine.modules.networks.LSTM",
|
||||
"hidden_sizes": [128],
|
||||
},
|
||||
"refinement_head": {
|
||||
"constructor": "iodine.modules.refinement.ResHead"
|
||||
},
|
||||
},
|
||||
"latent_dist": {
|
||||
"constructor": "iodine.modules.distributions.LocScaleDistribution",
|
||||
"dist": "normal",
|
||||
"scale_act": "softplus",
|
||||
"scale": "var",
|
||||
"name": "latent_dist",
|
||||
},
|
||||
"output_dist": {
|
||||
"constructor": "iodine.modules.distributions.MaskedMixture",
|
||||
"num_components": num_components,
|
||||
"component_dist": {
|
||||
"constructor":
|
||||
"iodine.modules.distributions.LocScaleDistribution",
|
||||
"dist":
|
||||
"logistic",
|
||||
"scale":
|
||||
"fixed",
|
||||
"scale_val":
|
||||
0.03,
|
||||
"name":
|
||||
"pixel_distribution",
|
||||
},
|
||||
},
|
||||
"factor_evaluator": {
|
||||
"constructor":
|
||||
"iodine.modules.factor_eval.FactorRegressor",
|
||||
"mapping": [
|
||||
("color", 3, "scalar"),
|
||||
("shape", 4, "categorical"),
|
||||
("scale", 1, "scalar"),
|
||||
("x", 1, "scalar"),
|
||||
("y", 1, "scalar"),
|
||||
("orientation", 2, "angle"),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
optimizer = {
|
||||
"constructor": "tensorflow.train.AdamOptimizer",
|
||||
"learning_rate": {
|
||||
"constructor": "tensorflow.train.exponential_decay",
|
||||
"learning_rate": learn_rate,
|
||||
"global_step": {
|
||||
"constructor": "tensorflow.train.get_or_create_global_step"
|
||||
},
|
||||
"decay_steps": 1000000,
|
||||
"decay_rate": 0.1,
|
||||
},
|
||||
"beta1": 0.95,
|
||||
}
|
||||
|
||||
|
||||
def tetrominoes():
|
||||
n_z = 32 # number of latent dimensions
|
||||
num_components = 4 # number of components (K)
|
||||
num_iters = 5
|
||||
checkpoint_dir = "iodine/checkpoints/tetrominoes"
|
||||
|
||||
# For the paper we used 8 GPUs with a batch size of 32 each.
|
||||
# This means a total batch size of 256, which is too large for a single GPU.
|
||||
# When reducing the batch size, the learning rate should also be lowered.
|
||||
batch_size = 128
|
||||
learn_rate = 0.0003 * math.sqrt(batch_size / 256)
|
||||
|
||||
data = {
|
||||
"constructor": "iodine.modules.data.Tetrominoes",
|
||||
"batch_size": batch_size,
|
||||
"path": "iodine/multi_object_datasets/tetrominoes_train.tfrecords",
|
||||
}
|
||||
|
||||
model = {
|
||||
"constructor": "iodine.modules.iodine.IODINE",
|
||||
"n_z": n_z,
|
||||
"num_components": num_components,
|
||||
"num_iters": num_iters,
|
||||
"iter_loss_weight": "linspace",
|
||||
"coord_type": "cos",
|
||||
"coord_freqs": 3,
|
||||
"decoder": {
|
||||
"constructor": "iodine.modules.decoder.ComponentDecoder",
|
||||
"pixel_decoder": {
|
||||
"constructor": "iodine.modules.networks.BroadcastConv",
|
||||
"cnn_opt": {
|
||||
# Final channels is irrelevant with target_output_shape
|
||||
"output_channels": [32, 32, 32, 32, None],
|
||||
"kernel_shapes": [5],
|
||||
"strides": [1],
|
||||
"activation": "elu",
|
||||
},
|
||||
"coord_type": "linear",
|
||||
"coord_freqs": 3,
|
||||
},
|
||||
},
|
||||
"refinement_core": {
|
||||
"constructor": "iodine.modules.refinement.RefinementCore",
|
||||
"encoder_net": {
|
||||
"constructor": "iodine.modules.networks.CNN",
|
||||
"mode": "avg_pool",
|
||||
"cnn_opt": {
|
||||
"output_channels": [32, 32, 32],
|
||||
"strides": [2],
|
||||
"kernel_shapes": [5],
|
||||
"activation": "elu",
|
||||
},
|
||||
"mlp_opt": {
|
||||
"output_sizes": [128],
|
||||
"activation": "elu"
|
||||
},
|
||||
},
|
||||
"recurrent_net": {
|
||||
"constructor": "iodine.modules.networks.LSTM",
|
||||
"hidden_sizes": [], # No recurrent layer used for this dataset
|
||||
},
|
||||
"refinement_head": {
|
||||
"constructor": "iodine.modules.refinement.ResHead"
|
||||
},
|
||||
},
|
||||
"latent_dist": {
|
||||
"constructor": "iodine.modules.distributions.LocScaleDistribution",
|
||||
"dist": "normal",
|
||||
"scale_act": "softplus",
|
||||
"scale": "var",
|
||||
"name": "latent_dist",
|
||||
},
|
||||
"output_dist": {
|
||||
"constructor": "iodine.modules.distributions.MaskedMixture",
|
||||
"num_components": num_components,
|
||||
"component_dist": {
|
||||
"constructor":
|
||||
"iodine.modules.distributions.LocScaleDistribution",
|
||||
"dist":
|
||||
"logistic",
|
||||
"scale":
|
||||
"fixed",
|
||||
"scale_val":
|
||||
0.03,
|
||||
"name":
|
||||
"pixel_distribution",
|
||||
},
|
||||
},
|
||||
"factor_evaluator": {
|
||||
"constructor":
|
||||
"iodine.modules.factor_eval.FactorRegressor",
|
||||
"mapping": [
|
||||
("position", 2, "scalar"),
|
||||
("color", 3, "scalar"),
|
||||
("shape", 20, "categorical"),
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
optimizer = {
|
||||
"constructor": "tensorflow.train.AdamOptimizer",
|
||||
"learning_rate": {
|
||||
"constructor": "tensorflow.train.exponential_decay",
|
||||
"learning_rate": learn_rate,
|
||||
"global_step": {
|
||||
"constructor": "tensorflow.train.get_or_create_global_step"
|
||||
},
|
||||
"decay_steps": 1000000,
|
||||
"decay_rate": 0.1,
|
||||
},
|
||||
"beta1": 0.95,
|
||||
}
|
||||
Executable
+20
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
pushd iodine
|
||||
wget http://storage.googleapis.com/deepmind-research-iodine/iodine_checkpoints.zip
|
||||
unzip iodine_checkpoints.zip
|
||||
popd
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
After Width: | Height: | Size: 148 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 15 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 17 KiB |
+202
@@ -0,0 +1,202 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# pylint: disable=g-importing-member, g-multiple-import, g-import-not-at-top
|
||||
# pylint: disable=protected-access, g-bad-import-order, missing-docstring
|
||||
# pylint: disable=unused-variable, invalid-name, no-value-for-parameter
|
||||
|
||||
from copy import deepcopy
|
||||
import os.path
|
||||
import warnings
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
from sacred import Experiment, SETTINGS
|
||||
|
||||
# Ignore all tensorflow deprecation warnings
|
||||
logging._warn_preinit_stderr = 0
|
||||
warnings.filterwarnings("ignore", module=".*tensorflow.*")
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
tf.logging.set_verbosity(tf.logging.ERROR)
|
||||
import sonnet as snt
|
||||
from sacred.stflow import LogFileWriter
|
||||
from iodine.modules import utils
|
||||
from iodine import configurations
|
||||
|
||||
SETTINGS.CONFIG.READ_ONLY_CONFIG = False
|
||||
|
||||
ex = Experiment("iodine")
|
||||
|
||||
|
||||
@ex.config
|
||||
def default_config():
|
||||
continue_run = False # set to continue experiment from an existing checkpoint
|
||||
checkpoint_dir = ("checkpoints/iodine"
|
||||
) # if continue_run is False, "_{run_id}" will be appended
|
||||
save_summaries_steps = 10
|
||||
save_checkpoint_steps = 1000
|
||||
|
||||
n_z = 64 # number of latent dimensions
|
||||
num_components = 7 # number of components (K)
|
||||
num_iters = 5
|
||||
|
||||
learn_rate = 0.001
|
||||
batch_size = 4
|
||||
stop_after_steps = int(1e6)
|
||||
|
||||
# Details for the dataset, model and optimizer are left empty here.
|
||||
# They can be found in the configurations for individual datasets,
|
||||
# which are provided in configurations.py and added as named configs.
|
||||
data = {} # Dataset details will go here
|
||||
model = {} # Model details will go here
|
||||
optimizer = {} # Optimizer details will go here
|
||||
|
||||
|
||||
ex.named_config(configurations.clevr6)
|
||||
ex.named_config(configurations.multi_dsprites)
|
||||
ex.named_config(configurations.tetrominoes)
|
||||
|
||||
|
||||
@ex.capture
|
||||
def build(identifier, _config):
|
||||
config_copy = deepcopy(_config[identifier])
|
||||
return utils.build(config_copy, identifier=identifier)
|
||||
|
||||
|
||||
def get_train_step(model, dataset, optimizer):
|
||||
loss, scalars, _ = model(dataset("train"))
|
||||
global_step = tf.train.get_or_create_global_step()
|
||||
grads = optimizer.compute_gradients(loss)
|
||||
gradients, variables = zip(*grads)
|
||||
global_norm = tf.global_norm(gradients)
|
||||
gradients, global_norm = tf.clip_by_global_norm(
|
||||
gradients, 5.0, use_norm=global_norm)
|
||||
grads = zip(gradients, variables)
|
||||
train_op = optimizer.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
with tf.control_dependencies([train_op]):
|
||||
overview = model.get_overview_images(dataset("summary"))
|
||||
scalars["debug/global_grad_norm"] = global_norm
|
||||
|
||||
summaries = {
|
||||
k: tf.summary.scalar(k, v) for k, v in scalars.items()
|
||||
}
|
||||
summaries.update(
|
||||
{k: tf.summary.image(k, v) for k, v in overview.items()})
|
||||
|
||||
return tf.identity(global_step), scalars, train_op
|
||||
|
||||
|
||||
@ex.capture
|
||||
def get_checkpoint_dir(continue_run, checkpoint_dir, _run, _log):
|
||||
if continue_run:
|
||||
assert os.path.exists(checkpoint_dir)
|
||||
_log.info("Continuing run from checkpoint at {}".format(checkpoint_dir))
|
||||
return checkpoint_dir
|
||||
|
||||
run_id = _run._id
|
||||
if run_id is None: # then no observer was added that provided an _id
|
||||
if not _run.unobserved:
|
||||
_log.warning(
|
||||
"No run_id given or provided by an Observer. (Re-)using run_id=1.")
|
||||
run_id = 1
|
||||
checkpoint_dir = checkpoint_dir + "_{run_id}".format(run_id=run_id)
|
||||
_log.info(
|
||||
"Starting a new run using checkpoint dir: '{}'".format(checkpoint_dir))
|
||||
return checkpoint_dir
|
||||
|
||||
|
||||
@ex.capture
|
||||
def get_session(chkp_dir, loss, stop_after_steps, save_summaries_steps,
|
||||
save_checkpoint_steps):
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.allow_growth = True
|
||||
|
||||
hooks = [
|
||||
tf.train.StopAtStepHook(last_step=stop_after_steps),
|
||||
tf.train.NanTensorHook(loss),
|
||||
]
|
||||
|
||||
return tf.train.MonitoredTrainingSession(
|
||||
hooks=hooks,
|
||||
config=config,
|
||||
checkpoint_dir=chkp_dir,
|
||||
save_summaries_steps=save_summaries_steps,
|
||||
save_checkpoint_steps=save_checkpoint_steps,
|
||||
)
|
||||
|
||||
|
||||
@ex.command(unobserved=True)
|
||||
def load_checkpoint(use_placeholder=False, session=None):
|
||||
dataset = build("data")
|
||||
model = build("model")
|
||||
if use_placeholder:
|
||||
inputs = dataset.get_placeholders()
|
||||
else:
|
||||
inputs = dataset()
|
||||
|
||||
info = model.eval(inputs)
|
||||
if session is None:
|
||||
session = tf.Session()
|
||||
saver = tf.train.Saver()
|
||||
checkpoint_dir = get_checkpoint_dir()
|
||||
checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
|
||||
saver.restore(session, checkpoint_file)
|
||||
|
||||
print('Successfully restored Checkpoint "{}"'.format(checkpoint_file))
|
||||
# print variables
|
||||
variables = tf.global_variables() + tf.local_variables()
|
||||
for row in snt.format_variables(variables, join_lines=False):
|
||||
print(row)
|
||||
|
||||
return {
|
||||
"session": session,
|
||||
"model": model,
|
||||
"info": info,
|
||||
"inputs": inputs,
|
||||
"dataset": dataset,
|
||||
}
|
||||
|
||||
|
||||
@ex.automain
|
||||
@LogFileWriter(ex)
|
||||
def main(save_summaries_steps):
|
||||
checkpoint_dir = get_checkpoint_dir()
|
||||
|
||||
dataset = build("data")
|
||||
model = build("model")
|
||||
optimizer = build("optimizer")
|
||||
gstep, train_step_exports, train_op = get_train_step(model, dataset,
|
||||
optimizer)
|
||||
|
||||
loss, ari = [], []
|
||||
with get_session(checkpoint_dir, train_step_exports["loss/total"]) as sess:
|
||||
while not sess.should_stop():
|
||||
out = sess.run({
|
||||
"step": gstep,
|
||||
"loss": train_step_exports["loss/total"],
|
||||
"ari": train_step_exports["loss/ari_nobg"],
|
||||
"train": train_op,
|
||||
})
|
||||
loss.append(out["loss"])
|
||||
ari.append(out["ari"])
|
||||
step = out["step"]
|
||||
if step % save_summaries_steps == 0:
|
||||
mean_loss = np.mean(loss)
|
||||
mean_ari = np.mean(ari)
|
||||
ex.log_scalar("loss", mean_loss, step)
|
||||
ex.log_scalar("ari", mean_ari, step)
|
||||
print("{step:>6d} Loss: {loss: >12.2f}\t\tARI-nobg:{ari: >6.2f}".format(
|
||||
step=step, loss=mean_loss, ari=mean_ari))
|
||||
loss, ari = [], []
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,264 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Data loading functionality for IODINE."""
|
||||
# pylint: disable=g-multiple-import, missing-docstring, unused-import
|
||||
import os.path
|
||||
|
||||
from iodine.modules.utils import flatten_all_but_last, ensure_3d
|
||||
from multi_object_datasets import (
|
||||
clevr_with_masks,
|
||||
multi_dsprites,
|
||||
tetrominoes,
|
||||
objects_room,
|
||||
)
|
||||
from shapeguard import ShapeGuard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class IODINEDataset(snt.AbstractModule):
|
||||
num_true_objects = 1
|
||||
num_channels = 3
|
||||
|
||||
factors = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
batch_size,
|
||||
image_dim,
|
||||
crop_region=None,
|
||||
shuffle_buffer=1000,
|
||||
max_num_objects=None,
|
||||
min_num_objects=None,
|
||||
grayscale=False,
|
||||
name="dataset",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self.path = os.path.abspath(os.path.expanduser(path))
|
||||
self.batch_size = batch_size
|
||||
self.crop_region = crop_region
|
||||
self.image_dim = image_dim
|
||||
self.shuffle_buffer = shuffle_buffer
|
||||
self.max_num_objects = max_num_objects
|
||||
self.min_num_objects = min_num_objects
|
||||
self.grayscale = grayscale
|
||||
self.dataset = None
|
||||
|
||||
def _build(self, subset="train"):
|
||||
dataset = self.dataset
|
||||
|
||||
# filter by number of objects
|
||||
if self.max_num_objects is not None or self.min_num_objects is not None:
|
||||
dataset = self.dataset.filter(self.filter_by_num_objects)
|
||||
|
||||
if subset == "train":
|
||||
# normal mode returns a shuffled dataset iterator
|
||||
if self.shuffle_buffer is not None:
|
||||
dataset = dataset.shuffle(self.shuffle_buffer)
|
||||
elif subset == "summary":
|
||||
# for generating summaries and overview images
|
||||
# returns a single fixed batch
|
||||
dataset = dataset.take(self.batch_size)
|
||||
|
||||
# repeat and batch
|
||||
dataset = dataset.repeat().batch(self.batch_size, drop_remainder=True)
|
||||
iterator = dataset.make_one_shot_iterator()
|
||||
data = iterator.get_next()
|
||||
|
||||
# preprocess the data to ensure correct format, scale images etc.
|
||||
data = self.preprocess(data)
|
||||
return data
|
||||
|
||||
def filter_by_num_objects(self, d):
|
||||
if "visibility" not in d:
|
||||
return tf.constant(True)
|
||||
min_num_objects = self.max_num_objects or 0
|
||||
max_num_objects = self.max_num_objects or 6
|
||||
|
||||
min_predicate = tf.greater_equal(
|
||||
tf.reduce_sum(d["visibility"]),
|
||||
tf.constant(min_num_objects - 1e-5, dtype=tf.float32),
|
||||
)
|
||||
max_predicate = tf.less_equal(
|
||||
tf.reduce_sum(d["visibility"]),
|
||||
tf.constant(max_num_objects + 1e-5, dtype=tf.float32),
|
||||
)
|
||||
return tf.logical_and(min_predicate, max_predicate)
|
||||
|
||||
def preprocess(self, data):
|
||||
sg = ShapeGuard(dims={
|
||||
"B": self.batch_size,
|
||||
"H": self.image_dim[0],
|
||||
"W": self.image_dim[1]
|
||||
})
|
||||
image = sg.guard(data["image"], "B, h, w, C")
|
||||
mask = sg.guard(data["mask"], "B, L, h, w, 1")
|
||||
|
||||
# to float
|
||||
image = tf.cast(image, tf.float32) / 255.0
|
||||
mask = tf.cast(mask, tf.float32) / 255.0
|
||||
|
||||
# crop
|
||||
if self.crop_region is not None:
|
||||
height_slice = slice(self.crop_region[0][0], self.crop_region[0][1])
|
||||
width_slice = slice(self.crop_region[1][0], self.crop_region[1][1])
|
||||
image = image[:, height_slice, width_slice, :]
|
||||
|
||||
mask = mask[:, :, height_slice, width_slice, :]
|
||||
|
||||
flat_mask, unflatten = flatten_all_but_last(mask, n_dims=3)
|
||||
|
||||
# rescale
|
||||
size = tf.constant(
|
||||
self.image_dim, dtype=tf.int32, shape=[2], verify_shape=True)
|
||||
image = tf.image.resize_images(
|
||||
image, size, method=tf.image.ResizeMethod.BILINEAR)
|
||||
mask = tf.image.resize_images(
|
||||
flat_mask, size, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
|
||||
|
||||
if self.grayscale:
|
||||
image = tf.reduce_mean(image, axis=-1, keepdims=True)
|
||||
|
||||
output = {
|
||||
"image": sg.guard(image[:, None], "B, T, H, W, C"),
|
||||
"mask": sg.guard(unflatten(mask)[:, None], "B, T, L, H, W, 1"),
|
||||
"factors": self.preprocess_factors(data, sg),
|
||||
}
|
||||
|
||||
if "visibility" in data:
|
||||
output["visibility"] = sg.guard(data["visibility"], "B, L")
|
||||
else:
|
||||
output["visibility"] = tf.ones(sg["B, L"], dtype=tf.float32)
|
||||
|
||||
return output
|
||||
|
||||
def preprocess_factors(self, data, sg):
|
||||
return {
|
||||
name: sg.guard(ensure_3d(data[name]), "B, L, *")
|
||||
for name in self.factors
|
||||
}
|
||||
|
||||
def get_placeholders(self, batch_size=None):
|
||||
batch_size = batch_size or self.batch_size
|
||||
sg = ShapeGuard(
|
||||
dims={
|
||||
"B": batch_size,
|
||||
"H": self.image_dim[0],
|
||||
"W": self.image_dim[1],
|
||||
"L": self.num_true_objects,
|
||||
"C": 3,
|
||||
"T": 1,
|
||||
})
|
||||
return {
|
||||
"image": tf.placeholder(dtype=tf.float32, shape=sg["B, T, H, W, C"]),
|
||||
"mask": tf.placeholder(dtype=tf.float32, shape=sg["B, T, L, H, W, 1"]),
|
||||
"visibility": tf.placeholder(dtype=tf.float32, shape=sg["B, L"]),
|
||||
"factors": {
|
||||
name:
|
||||
tf.placeholder(dtype=dtype, shape=sg["B, L, {}".format(size)])
|
||||
for name, (dtype, size) in self.factors
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class CLEVR(IODINEDataset):
|
||||
num_true_objects = 11
|
||||
num_channels = 3
|
||||
factors = {
|
||||
"color": (tf.uint8, 1),
|
||||
"shape": (tf.uint8, 1),
|
||||
"size": (tf.uint8, 1),
|
||||
"position": (tf.float32, 3),
|
||||
"rotation": (tf.float32, 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
crop_region=((29, 221), (64, 256)),
|
||||
image_dim=(128, 128),
|
||||
name="clevr",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
path=path,
|
||||
crop_region=crop_region,
|
||||
image_dim=image_dim,
|
||||
name=name,
|
||||
**kwargs)
|
||||
self.dataset = clevr_with_masks.dataset(self.path)
|
||||
|
||||
def preprocess_factors(self, data, sg):
|
||||
|
||||
return {
|
||||
"color": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
|
||||
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
|
||||
"size": sg.guard(ensure_3d(data["color"]), "B, L, 1"),
|
||||
"position": sg.guard(ensure_3d(data["pixel_coords"]), "B, L, 3"),
|
||||
"rotation": sg.guard(ensure_3d(data["rotation"]), "B, L, 1"),
|
||||
}
|
||||
|
||||
|
||||
class MultiDSprites(IODINEDataset):
|
||||
num_true_objects = 6
|
||||
num_channels = 3
|
||||
factors = {
|
||||
"color": (tf.float32, 3),
|
||||
"shape": (tf.uint8, 1),
|
||||
"scale": (tf.float32, 1),
|
||||
"x": (tf.float32, 1),
|
||||
"y": (tf.float32, 1),
|
||||
"orientation": (tf.float32, 1),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
path,
|
||||
# variant from ['binarized', 'colored_on_grayscale', 'colored_on_colored']
|
||||
dataset_variant="colored_on_grayscale",
|
||||
image_dim=(64, 64),
|
||||
name="multi_dsprites",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
|
||||
self.dataset_variant = dataset_variant
|
||||
self.dataset = multi_dsprites.dataset(self.path, self.dataset_variant)
|
||||
|
||||
|
||||
class Tetrominoes(IODINEDataset):
|
||||
num_true_objects = 6
|
||||
num_channels = 3
|
||||
factors = {
|
||||
"color": (tf.uint8, 3),
|
||||
"shape": (tf.uint8, 1),
|
||||
"position": (tf.float32, 2),
|
||||
}
|
||||
|
||||
def __init__(self, path, image_dim=(35, 35), name="tetrominoes", **kwargs):
|
||||
super().__init__(path=path, name=name, image_dim=image_dim, **kwargs)
|
||||
self.dataset = tetrominoes.dataset(self.path)
|
||||
|
||||
def preprocess_factors(self, data, sg):
|
||||
pos_x = ensure_3d(data["x"])
|
||||
pos_y = ensure_3d(data["y"])
|
||||
position = tf.concat([pos_x, pos_y], axis=2)
|
||||
|
||||
return {
|
||||
"color": sg.guard(ensure_3d(data["color"]), "B, L, 3"),
|
||||
"shape": sg.guard(ensure_3d(data["shape"]), "B, L, 1"),
|
||||
"position": sg.guard(ensure_3d(position), "B, L, 2"),
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Decoders for rendering images."""
|
||||
# pylint: disable=missing-docstring
|
||||
from iodine.modules.distributions import MixtureParameters
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
|
||||
|
||||
class ComponentDecoder(snt.AbstractModule):
|
||||
|
||||
def __init__(self, pixel_decoder, name="component_decoder"):
|
||||
super().__init__(name=name)
|
||||
self._pixel_decoder = pixel_decoder
|
||||
self._sg = shapeguard.ShapeGuard()
|
||||
|
||||
def set_output_shapes(self, pixel, mask):
|
||||
self._sg.guard(pixel, "K, H, W, Cp")
|
||||
self._sg.guard(mask, "K, H, W, 1")
|
||||
self._pixel_decoder.set_output_shapes(self._sg["H, W, 1 + Cp"])
|
||||
|
||||
def _build(self, z):
|
||||
self._sg.guard(z, "B, K, Z")
|
||||
z_flat = self._sg.reshape(z, "B*K, Z")
|
||||
pixel_params = self._pixel_decoder(z_flat).params
|
||||
|
||||
self._sg.guard(pixel_params, "B*K, H, W, 1 + Cp")
|
||||
mask_params = pixel_params[Ellipsis, 0:1]
|
||||
pixel_params = pixel_params[Ellipsis, 1:]
|
||||
|
||||
output = MixtureParameters(
|
||||
pixel=self._sg.reshape(pixel_params, "B, K, H, W, Cp"),
|
||||
mask=self._sg.reshape(mask_params, "B, K, H, W, 1"),
|
||||
)
|
||||
|
||||
del self._sg.B
|
||||
return output
|
||||
@@ -0,0 +1,223 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Collection of sonnet modules that wrap useful distributions."""
|
||||
# pylint: disable=missing-docstring, g-doc-args, g-short-docstring-punctuation
|
||||
# pylint: disable=g-space-before-docstring-summary
|
||||
# pylint: disable=g-no-space-after-docstring-summary
|
||||
import collections
|
||||
from iodine.modules.utils import get_act_func
|
||||
from iodine.modules.utils import get_distribution
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow_probability as tfp
|
||||
|
||||
|
||||
tfd = tfp.distributions
|
||||
|
||||
FlatParameters = collections.namedtuple("ParameterOut", ["params"])
|
||||
MixtureParameters = collections.namedtuple("MixtureOut", ["pixel", "mask"])
|
||||
|
||||
|
||||
class DistributionModule(snt.AbstractModule):
|
||||
"""Distribution Base class supporting shape inference & default priors."""
|
||||
|
||||
def __init__(self, name="distribution"):
|
||||
super().__init__(name=name)
|
||||
self._output_shape = None
|
||||
|
||||
def set_output_shape(self, shape):
|
||||
self._output_shape = shape
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return self._output_shape
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_default_prior(self, batch_dim=(1,)):
|
||||
return self(
|
||||
tf.zeros(list(batch_dim) + self.input_shapes.params, dtype=tf.float32))
|
||||
|
||||
|
||||
class BernoulliOutput(DistributionModule):
|
||||
|
||||
def __init__(self, name="bernoulli_output"):
|
||||
super().__init__(name=name)
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
return FlatParameters(self.output_shape)
|
||||
|
||||
def _build(self, params):
|
||||
return tfd.Independent(
|
||||
tfd.Bernoulli(logits=params, dtype=tf.float32),
|
||||
reinterpreted_batch_ndims=1)
|
||||
|
||||
|
||||
class LocScaleDistribution(DistributionModule):
|
||||
"""Generic IID location / scale distribution.
|
||||
|
||||
Input parameters are concatenation of location and scale (2*Z,)
|
||||
|
||||
Args:
|
||||
dist: Distribution or str Kind of distribution used. Supports Normal,
|
||||
Logistic, Laplace, and StudentT distributions.
|
||||
dist_kwargs: dict custom keyword arguments for the distribution
|
||||
scale_act: function or str or None activation function to be applied to
|
||||
the scale input
|
||||
scale: str
|
||||
different modes for computing the scale:
|
||||
* stddev: scale is computed as scale_act(s)
|
||||
* var: scale is computed as sqrt(scale_act(s))
|
||||
* prec: scale is computed as 1./scale_act(s)
|
||||
* fixed: scale is a global variable (same for all pixels) if
|
||||
scale_val==-1. then it is a trainable variable initialized to 0.1
|
||||
else it is fixed to scale_val (input shape is only (Z,) in this
|
||||
case)
|
||||
scale_val: float determines the scale value (only used if scale=='fixed').
|
||||
loc_act: function or str or None activation function to be applied to the
|
||||
location input. Supports optional activation functions for scale and
|
||||
location.
|
||||
Supports different "modes" for scaling:
|
||||
* stddev:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dist=tfd.Normal,
|
||||
dist_kwargs=None,
|
||||
scale_act=tf.exp,
|
||||
scale="stddev",
|
||||
scale_val=1.0,
|
||||
loc_act=None,
|
||||
name="loc_scale_dist",
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._scale_act = get_act_func(scale_act)
|
||||
self._loc_act = get_act_func(loc_act)
|
||||
# supports Normal, Logstic, Laplace, StudentT
|
||||
self._dist = get_distribution(dist)
|
||||
self._dist_kwargs = dist_kwargs or {}
|
||||
|
||||
assert scale in ["stddev", "var", "prec", "fixed"], scale
|
||||
self._scale = scale
|
||||
self._scale_val = scale_val
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
if self._scale == "fixed":
|
||||
param_shape = self.output_shape
|
||||
else:
|
||||
param_shape = self.output_shape[:-1] + [self.output_shape[-1] * 2]
|
||||
return FlatParameters(param_shape)
|
||||
|
||||
def _build(self, params):
|
||||
if self._scale == "fixed":
|
||||
loc = params
|
||||
scale = None # set later
|
||||
else:
|
||||
n_channels = params.get_shape().as_list()[-1]
|
||||
assert n_channels % 2 == 0
|
||||
assert n_channels // 2 == self.output_shape[-1]
|
||||
loc = params[Ellipsis, :n_channels // 2]
|
||||
scale = params[Ellipsis, n_channels // 2:]
|
||||
|
||||
# apply activation functions
|
||||
if self._scale != "fixed":
|
||||
scale = self._scale_act(scale)
|
||||
loc = self._loc_act(loc)
|
||||
|
||||
# apply the correct parametrization
|
||||
if self._scale == "var":
|
||||
scale = tf.sqrt(scale)
|
||||
elif self._scale == "prec":
|
||||
scale = tf.reciprocal(scale)
|
||||
elif self._scale == "fixed":
|
||||
if self._scale_val == -1.0:
|
||||
scale_val = tf.get_variable(
|
||||
"scale", initializer=tf.constant(0.1, dtype=tf.float32))
|
||||
else:
|
||||
scale_val = self._scale_val
|
||||
scale = tf.ones_like(loc) * scale_val
|
||||
# else 'stddev'
|
||||
|
||||
dist = self._dist(loc=loc, scale=scale, **self._dist_kwargs)
|
||||
|
||||
return tfd.Independent(dist, reinterpreted_batch_ndims=1)
|
||||
|
||||
|
||||
class MaskedMixture(DistributionModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_components,
|
||||
component_dist,
|
||||
mask_activation=None,
|
||||
name="masked_mixture",
|
||||
):
|
||||
"""
|
||||
Spatial Mixture Model composed of a categorical masking distribution and
|
||||
a custom pixel-wise component distribution (usually logistic or
|
||||
gaussian).
|
||||
|
||||
Args:
|
||||
num_components: int Number of mixture components >= 2
|
||||
component_dist: the distribution to use for the individual components
|
||||
mask_activation: str or function or None activation function that
|
||||
should be applied to the mask before the softmax.
|
||||
name: str
|
||||
"""
|
||||
|
||||
super().__init__(name=name)
|
||||
self._num_components = num_components
|
||||
self._dist = component_dist
|
||||
self._mask_activation = get_act_func(mask_activation)
|
||||
|
||||
def set_output_shape(self, shape):
|
||||
super().set_output_shape(shape)
|
||||
self._dist.set_output_shape(shape)
|
||||
|
||||
def _build(self, pixel, mask):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
# MASKING
|
||||
sg.guard(mask, "B, K, H, W, 1")
|
||||
mask = tf.transpose(mask, perm=[0, 2, 3, 4, 1])
|
||||
mask = sg.reshape(mask, "B, H, W, K")
|
||||
mask = self._mask_activation(mask)
|
||||
mask = mask[:, tf.newaxis] # add K=1 axis since K is removed by mixture
|
||||
mix_dist = tfd.Categorical(logits=mask)
|
||||
|
||||
# COMPONENTS
|
||||
sg.guard(pixel, "B, K, H, W, Cp")
|
||||
params = tf.transpose(pixel, perm=[0, 2, 3, 1, 4])
|
||||
params = params[:, tf.newaxis] # add K=1 axis since K is removed by mixture
|
||||
dist = self._dist(params)
|
||||
return tfd.MixtureSameFamily(
|
||||
mixture_distribution=mix_dist, components_distribution=dist)
|
||||
|
||||
@property
|
||||
def input_shapes(self):
|
||||
pixel = [self._num_components] + self._dist.input_shapes.params
|
||||
mask = pixel[:-1] + [1]
|
||||
return MixtureParameters(pixel, mask)
|
||||
|
||||
def get_default_prior(self, batch_dim=(1,)):
|
||||
pixel = tf.zeros(
|
||||
list(batch_dim) + self.input_shapes.pixel, dtype=tf.float32)
|
||||
mask = tf.zeros(list(batch_dim) + self.input_shapes.mask, dtype=tf.float32)
|
||||
return self(pixel, mask)
|
||||
@@ -0,0 +1,206 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Factor Evaluation Module."""
|
||||
# pylint: disable=unused-variable
|
||||
|
||||
import collections
|
||||
import functools
|
||||
from iodine.modules import utils
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
Factor = collections.namedtuple("Factor", ["name", "size", "type"])
|
||||
|
||||
|
||||
class FactorRegressor(snt.AbstractModule):
|
||||
"""Assess representations by learning a linear mapping to latents."""
|
||||
|
||||
def __init__(self, mapping=None, name="repres_content"):
|
||||
super().__init__(name=name)
|
||||
if mapping is None:
|
||||
self._mapping = [
|
||||
Factor("color", 3, "scalar"),
|
||||
Factor("shape", 4, "categorical"),
|
||||
Factor("scale", 1, "scalar"),
|
||||
Factor("x", 1, "scalar"),
|
||||
Factor("y", 1, "scalar"),
|
||||
Factor("orientation", 2, "angle"),
|
||||
]
|
||||
else:
|
||||
self._mapping = [Factor(*m) for m in mapping]
|
||||
|
||||
def _build(self, z, latent, visibility, pred_mask, true_mask):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
z = sg.guard(z, "B, K, Z")
|
||||
pred_mask = sg.guard(pred_mask, "B, K, H, W, 1")
|
||||
true_mask = sg.guard(true_mask, "B, L, H, W, 1")
|
||||
|
||||
visibility = sg.guard(visibility, "B, L")
|
||||
num_visible_obj = tf.reduce_sum(visibility)
|
||||
|
||||
# Map z to predictions for all latents
|
||||
sg.M = sum([m.size for m in self._mapping])
|
||||
self.predictor = snt.Linear(sg.M, name="predict_latents")
|
||||
z_flat = sg.reshape(z, "B*K, Z")
|
||||
all_preds = sg.guard(self.predictor(z_flat), "B*K, M")
|
||||
all_preds = sg.reshape(all_preds, "B, 1, K, M")
|
||||
all_preds = tf.tile(all_preds, sg["1, L, 1, 1"])
|
||||
|
||||
# prepare latents
|
||||
latents = {}
|
||||
mean_var_tot = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
# preprocess, reshape, and tile
|
||||
lat_preprocess = self.get_preprocessing(m)
|
||||
lat = sg.guard(
|
||||
lat_preprocess(latent[m.name]), "B, L, {}".format(m.size))
|
||||
# compute mean over latent by training a variable using mse
|
||||
if m.type in {"scalar", "angle"}:
|
||||
mvt = utils.OnlineMeanVarEstimator(
|
||||
axis=[0, 1], ddof=1, name="{}_mean_var".format(m.name))
|
||||
mean_var_tot[m.name] = mvt(lat, visibility[:, :, tf.newaxis])
|
||||
|
||||
lat = tf.reshape(lat, sg["B, L, 1"] + [-1])
|
||||
lat = tf.tile(lat, sg["1, 1, K, 1"])
|
||||
latents[m.name] = lat
|
||||
|
||||
# prepare predictions
|
||||
idx = 0
|
||||
predictions = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
assert m.name in latent, "{} not in {}".format(m.name, latent.keys())
|
||||
pred = all_preds[Ellipsis, idx:idx + m.size]
|
||||
predictions[m.name] = sg.guard(pred, "B, L, K, {}".format(m.size))
|
||||
idx += m.size
|
||||
|
||||
# compute error
|
||||
total_pairwise_errors = None
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
error_fn = self.get_error_func(m)
|
||||
sg.guard(latents[m.name], "B, L, K, {}".format(m.size))
|
||||
sg.guard(predictions[m.name], "B, L, K, {}".format(m.size))
|
||||
err = error_fn(latents[m.name], predictions[m.name])
|
||||
sg.guard(err, "B, L, K")
|
||||
if total_pairwise_errors is None:
|
||||
total_pairwise_errors = err
|
||||
else:
|
||||
total_pairwise_errors += err
|
||||
|
||||
# determine best assignment by comparing masks
|
||||
obj_mask = true_mask[:, :, tf.newaxis]
|
||||
pred_mask = pred_mask[:, tf.newaxis]
|
||||
pairwise_overlap = tf.reduce_sum(obj_mask * pred_mask, axis=[3, 4, 5])
|
||||
best_match = sg.guard(tf.argmax(pairwise_overlap, axis=2), "B, L")
|
||||
assignment = tf.one_hot(best_match, sg.K)
|
||||
assignment *= visibility[:, :, tf.newaxis] # Mask non-visible objects
|
||||
|
||||
# total error
|
||||
total_error = (
|
||||
tf.reduce_sum(assignment * total_pairwise_errors) / num_visible_obj)
|
||||
|
||||
# compute scalars
|
||||
monitored_scalars = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
metric = self.get_metric(m)
|
||||
scalar = metric(
|
||||
latents[m.name],
|
||||
predictions[m.name],
|
||||
assignment[:, :, :, tf.newaxis],
|
||||
mean_var_tot.get(m.name),
|
||||
num_visible_obj,
|
||||
)
|
||||
monitored_scalars[m.name] = scalar
|
||||
return total_error, monitored_scalars, mean_var_tot, predictions, assignment
|
||||
|
||||
@snt.reuse_variables
|
||||
def predict(self, z):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
z = sg.guard(z, "B, Z")
|
||||
all_preds = sg.guard(self.predictor(z), "B, M")
|
||||
|
||||
idx = 0
|
||||
predictions = {}
|
||||
for m in self._mapping:
|
||||
with tf.name_scope(m.name):
|
||||
pred = all_preds[:, idx:idx + m.size]
|
||||
predictions[m.name] = sg.guard(pred, "B, {}".format(m.size))
|
||||
idx += m.size
|
||||
return predictions
|
||||
|
||||
@staticmethod
|
||||
def get_error_func(factor):
|
||||
if factor.type in {"scalar", "angle"}:
|
||||
return sse
|
||||
elif factor.type == "categorical":
|
||||
return functools.partial(
|
||||
tf.losses.softmax_cross_entropy, reduction="none")
|
||||
else:
|
||||
raise KeyError(factor.type)
|
||||
|
||||
@staticmethod
|
||||
def get_metric(factor):
|
||||
if factor.type in {"scalar", "angle"}:
|
||||
return r2
|
||||
elif factor.type == "categorical":
|
||||
return accuracy
|
||||
else:
|
||||
raise KeyError(factor.type)
|
||||
|
||||
@staticmethod
|
||||
def one_hot(f, nr_categories):
|
||||
return tf.one_hot(tf.cast(f[Ellipsis, 0], tf.int32), depth=nr_categories)
|
||||
|
||||
@staticmethod
|
||||
def angle_to_vector(theta):
|
||||
return tf.concat([tf.math.cos(theta), tf.math.sin(theta)], axis=-1)
|
||||
|
||||
@staticmethod
|
||||
def get_preprocessing(factor):
|
||||
if factor.type == "scalar":
|
||||
return tf.identity
|
||||
elif factor.type == "categorical":
|
||||
return functools.partial(
|
||||
FactorRegressor.one_hot, nr_categories=factor.size)
|
||||
elif factor.type == "angle":
|
||||
return FactorRegressor.angle_to_vector
|
||||
else:
|
||||
raise KeyError(factor.type)
|
||||
|
||||
|
||||
def sse(true, pred):
|
||||
# run our own sum squared error because we want to reduce sum over last dim
|
||||
return tf.reduce_sum(tf.square(true - pred), axis=-1)
|
||||
|
||||
|
||||
def accuracy(labels, logits, assignment, mean_var_tot, num_vis):
|
||||
del mean_var_tot # unused
|
||||
pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
|
||||
labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
|
||||
correct = tf.cast(tf.equal(labels, pred), tf.float32)
|
||||
return tf.reduce_sum(correct * assignment[Ellipsis, 0]) / num_vis
|
||||
|
||||
|
||||
def r2(labels, pred, assignment, mean_var_tot, num_vis):
|
||||
del num_vis # unused
|
||||
mean, var, _ = mean_var_tot
|
||||
# labels, pred: (B, L, K, n)
|
||||
ss_res = tf.reduce_sum(tf.square(labels - pred) * assignment, axis=2)
|
||||
ss_tot = var[tf.newaxis, tf.newaxis, :] # (1, 1, n)
|
||||
return tf.reduce_mean(1.0 - ss_res / ss_tot)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,300 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Network modules."""
|
||||
# pylint: disable=g-multiple-import, g-doc-args, g-short-docstring-punctuation
|
||||
# pylint: disable=g-no-space-after-docstring-summary
|
||||
from iodine.modules.distributions import FlatParameters
|
||||
from iodine.modules.utils import flatten_all_but_last, get_act_func
|
||||
import numpy as np
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class CNN(snt.AbstractModule):
|
||||
"""ConvNet2D followed by an MLP.
|
||||
|
||||
This is a typical encoder architecture for VAEs, and has been found to work
|
||||
well. One small improvement is to append coordinate channels on the input,
|
||||
though for most datasets the improvement obtained is negligible.
|
||||
"""
|
||||
|
||||
def __init__(self, cnn_opt, mlp_opt, mode="flatten", name="cnn"):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
cnn_opt: Dictionary. Kwargs for the cnn. See vae_lib.ConvNet2D for
|
||||
details.
|
||||
mlp_opt: Dictionary. Kwargs for the mlp. See vae_lib.MLP for details.
|
||||
name: String. Optional name.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
if "activation" in cnn_opt:
|
||||
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||
self._cnn_opt = cnn_opt
|
||||
|
||||
if "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
|
||||
self._mode = mode
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
# assert self._mlp_opt['output_sizes'][-1] is None, self._mlp_opt
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(shape, "1, Y")
|
||||
self._mlp_opt["output_sizes"][-1] = sg.Y
|
||||
|
||||
def _build(self, image):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_image, unflatten = flatten_all_but_last(image, n_dims=3)
|
||||
sg.guard(flat_image, "B, H, W, C")
|
||||
|
||||
cnn = snt.nets.ConvNet2D(
|
||||
activate_final=True,
|
||||
paddings=("SAME",),
|
||||
normalize_final=False,
|
||||
**self._cnn_opt)
|
||||
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||
|
||||
# run CNN
|
||||
net = cnn(flat_image)
|
||||
|
||||
if self._mode == "flatten":
|
||||
# flatten
|
||||
net_shape = net.get_shape().as_list()
|
||||
flat_shape = net_shape[:-3] + [np.prod(net_shape[-3:])]
|
||||
net = tf.reshape(net, flat_shape)
|
||||
elif self._mode == "avg_pool":
|
||||
net = tf.reduce_mean(net, axis=[1, 2])
|
||||
else:
|
||||
raise KeyError('Unknown mode "{}"'.format(self._mode))
|
||||
# run MLP
|
||||
output = sg.guard(mlp(net), "B, Y")
|
||||
return FlatParameters(unflatten(output))
|
||||
|
||||
|
||||
class MLP(snt.AbstractModule):
|
||||
"""MLP."""
|
||||
|
||||
def __init__(self, name="mlp", **mlp_opt):
|
||||
super().__init__(name=name)
|
||||
if "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
assert mlp_opt["output_sizes"][-1] is None, mlp_opt
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(shape, "1, Y")
|
||||
self._mlp_opt["output_sizes"][-1] = sg.Y
|
||||
|
||||
def _build(self, data):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
assert self._mlp_opt["output_sizes"][-1] is not None, "set output_shapes"
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_data, unflatten = flatten_all_but_last(data)
|
||||
sg.guard(flat_data, "B, N")
|
||||
|
||||
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||
# run MLP
|
||||
output = sg.guard(mlp(flat_data), "B, Y")
|
||||
return FlatParameters(unflatten(output))
|
||||
|
||||
|
||||
class DeConv(snt.AbstractModule):
|
||||
"""MLP followed by Deconv net.
|
||||
|
||||
This decoder is commonly used by vanilla VAE models. However, in practice
|
||||
BroadcastConv (see below) seems to disentangle slightly better.
|
||||
"""
|
||||
|
||||
def __init__(self, mlp_opt, cnn_opt, name="deconv"):
|
||||
"""Constructor.
|
||||
|
||||
Args:
|
||||
mlp_opt: Dictionary. Kwargs for vae_lib.MLP.
|
||||
cnn_opt: Dictionary. Kwargs for vae_lib.ConvNet2D for the CNN.
|
||||
name: Optional name.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
assert cnn_opt["output_channels"][-1] is None, cnn_opt
|
||||
if "activation" in cnn_opt:
|
||||
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||
self._cnn_opt = cnn_opt
|
||||
|
||||
if mlp_opt and "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
self._target_out_shape = None
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
self._target_out_shape = shape
|
||||
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
|
||||
|
||||
def _build(self, z):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_z, unflatten = flatten_all_but_last(z)
|
||||
sg.guard(flat_z, "B, Z")
|
||||
sg.guard(self._target_out_shape, "H, W, C")
|
||||
mlp = snt.nets.MLP(**self._mlp_opt)
|
||||
cnn = snt.nets.ConvNet2DTranspose(
|
||||
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
|
||||
net = mlp(flat_z)
|
||||
output = sg.guard(cnn(net), "B, H, W, C")
|
||||
return FlatParameters(unflatten(output))
|
||||
|
||||
|
||||
class BroadcastConv(snt.AbstractModule):
|
||||
"""MLP followed by a broadcast convolution.
|
||||
|
||||
This decoder takes a latent vector z, (optionally) applies an MLP to it,
|
||||
then tiles the resulting vector across space to have dimension [B, H, W, C]
|
||||
i.e. tiles across H and W. Then coordinate channels are appended and a
|
||||
convolutional layer is applied.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cnn_opt,
|
||||
mlp_opt=None,
|
||||
coord_type="linear",
|
||||
coord_freqs=3,
|
||||
name="broadcast_conv",
|
||||
):
|
||||
"""Args:
|
||||
cnn_opt: dict Kwargs for vae_lib.ConvNet2D for the CNN.
|
||||
mlp_opt: None or dict If dictionary, then kwargs for snt.nets.MLP. If
|
||||
None, then the model will not process the latent vector by an mlp.
|
||||
coord_type: ["linear", "cos", None] type of coordinate channels to
|
||||
add.
|
||||
None: add no coordinate channels.
|
||||
linear: two channels with values linearly spaced from -1. to 1. in
|
||||
the H and W dimension respectively.
|
||||
cos: coord_freqs^2 many channels containing cosine basis functions.
|
||||
coord_freqs: int number of frequencies used to construct the cosine
|
||||
basis functions (only for coord_type=="cos")
|
||||
name: Optional name.
|
||||
"""
|
||||
super().__init__(name=name)
|
||||
|
||||
assert cnn_opt["output_channels"][-1] is None, cnn_opt
|
||||
if "activation" in cnn_opt:
|
||||
cnn_opt["activation"] = get_act_func(cnn_opt["activation"])
|
||||
self._cnn_opt = cnn_opt
|
||||
|
||||
if mlp_opt and "activation" in mlp_opt:
|
||||
mlp_opt["activation"] = get_act_func(mlp_opt["activation"])
|
||||
self._mlp_opt = mlp_opt
|
||||
|
||||
self._target_out_shape = None
|
||||
self._coord_type = coord_type
|
||||
self._coord_freqs = coord_freqs
|
||||
|
||||
def set_output_shapes(self, shape):
|
||||
self._target_out_shape = shape
|
||||
self._cnn_opt["output_channels"][-1] = self._target_out_shape[-1]
|
||||
|
||||
def _build(self, z):
|
||||
"""Connect model to TensorFlow graph."""
|
||||
assert self._target_out_shape is not None, "Call set_output_shape"
|
||||
# reshape components into batch dimension before processing them
|
||||
sg = shapeguard.ShapeGuard()
|
||||
flat_z, unflatten = flatten_all_but_last(z)
|
||||
sg.guard(flat_z, "B, Z")
|
||||
sg.guard(self._target_out_shape, "H, W, C")
|
||||
|
||||
if self._mlp_opt is None:
|
||||
mlp = tf.identity
|
||||
else:
|
||||
mlp = snt.nets.MLP(activate_final=True, **self._mlp_opt)
|
||||
mlp_output = sg.guard(mlp(flat_z), "B, hidden")
|
||||
|
||||
# tile MLP output spatially and append coordinate channels
|
||||
broadcast_mlp_output = tf.tile(
|
||||
mlp_output[:, tf.newaxis, tf.newaxis],
|
||||
multiples=tf.constant(sg["1, H, W, 1"]),
|
||||
) # B, H, W, Z
|
||||
|
||||
dec_cnn_inputs = self.append_coordinate_channels(broadcast_mlp_output)
|
||||
|
||||
cnn = snt.nets.ConvNet2D(
|
||||
paddings=("SAME",), normalize_final=False, **self._cnn_opt)
|
||||
cnn_outputs = cnn(dec_cnn_inputs)
|
||||
sg.guard(cnn_outputs, "B, H, W, C")
|
||||
|
||||
return FlatParameters(unflatten(cnn_outputs))
|
||||
|
||||
def append_coordinate_channels(self, output):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(output, "B, H, W, C")
|
||||
if self._coord_type is None:
|
||||
return output
|
||||
if self._coord_type == "linear":
|
||||
w_coords = tf.linspace(-1.0, 1.0, sg.W)[None, None, :, None]
|
||||
h_coords = tf.linspace(-1.0, 1.0, sg.H)[None, :, None, None]
|
||||
w_coords = tf.tile(w_coords, sg["B, H, 1, 1"])
|
||||
h_coords = tf.tile(h_coords, sg["B, 1, W, 1"])
|
||||
return tf.concat([output, h_coords, w_coords], axis=-1)
|
||||
elif self._coord_type == "cos":
|
||||
freqs = sg.guard(tf.range(0.0, self._coord_freqs), "F")
|
||||
valx = tf.linspace(0.0, np.pi, sg.W)[None, None, :, None, None]
|
||||
valy = tf.linspace(0.0, np.pi, sg.H)[None, :, None, None, None]
|
||||
x_basis = tf.cos(valx * freqs[None, None, None, :, None])
|
||||
y_basis = tf.cos(valy * freqs[None, None, None, None, :])
|
||||
xy_basis = tf.reshape(x_basis * y_basis, sg["1, H, W, F*F"])
|
||||
coords = tf.tile(xy_basis, sg["B, 1, 1, 1"])[Ellipsis, 1:]
|
||||
return tf.concat([output, coords], axis=-1)
|
||||
else:
|
||||
raise KeyError('Unknown coord_type: "{}"'.format(self._coord_type))
|
||||
|
||||
|
||||
class LSTM(snt.RNNCore):
|
||||
"""Wrapper around snt.LSTM that supports multi-layers and runs K components in
|
||||
parallel.
|
||||
|
||||
Expects input data of shape (B, K, H) and outputs data of shape (B, K, Y)
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_sizes, name="lstm"):
|
||||
super().__init__(name=name)
|
||||
self._hidden_sizes = hidden_sizes
|
||||
with self._enter_variable_scope():
|
||||
self._lstm_layers = [snt.LSTM(hidden_size=h) for h in self._hidden_sizes]
|
||||
|
||||
def initial_state(self, batch_size, **kwargs):
|
||||
return [
|
||||
lstm.initial_state(batch_size, **kwargs) for lstm in self._lstm_layers
|
||||
]
|
||||
|
||||
def _build(self, data, prev_states):
|
||||
assert not self._hidden_sizes or self._hidden_sizes[-1] is not None
|
||||
assert len(prev_states) == len(self._hidden_sizes)
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(data, "B, K, H")
|
||||
data = sg.reshape(data, "B*K, H")
|
||||
|
||||
out = data
|
||||
new_states = []
|
||||
for lstm, pstate in zip(self._lstm_layers, prev_states):
|
||||
out, nstate = lstm(out, pstate)
|
||||
new_states.append(nstate)
|
||||
|
||||
sg.guard(out, "B*K, Y")
|
||||
out = sg.reshape(out, "B, K, Y")
|
||||
return out, new_states
|
||||
@@ -0,0 +1,226 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Plotting tools for IODINE."""
|
||||
# pylint: disable=unused-import, missing-docstring, unused-variable
|
||||
# pylint: disable=invalid-name, unexpected-keyword-arg
|
||||
import functools
|
||||
from iodine.modules.utils import get_mask_plot_colors
|
||||
from matplotlib.colors import hsv_to_rgb
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
__all__ = ("get_mask_plot_colors", "example_plot", "iterations_plot",
|
||||
"inputs_plot")
|
||||
|
||||
|
||||
def clean_ax(ax, color=None, lw=4.0):
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
if color is not None:
|
||||
for spine in ax.spines.values():
|
||||
spine.set_linewidth(lw)
|
||||
spine.set_color(color)
|
||||
|
||||
|
||||
def optional_ax(fn):
|
||||
|
||||
def _wrapped(*args, **kwargs):
|
||||
if kwargs.get("ax", None) is None:
|
||||
figsize = kwargs.pop("figsize", (4, 4))
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
kwargs["ax"] = ax
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
def optional_clean_ax(fn):
|
||||
|
||||
def _wrapped(*args, **kwargs):
|
||||
if kwargs.get("ax", None) is None:
|
||||
figsize = kwargs.pop("figsize", (4, 4))
|
||||
fig, ax = plt.subplots(figsize=figsize)
|
||||
kwargs["ax"] = ax
|
||||
color = kwargs.pop("color", None)
|
||||
lw = kwargs.pop("lw", 4.0)
|
||||
res = fn(*args, **kwargs)
|
||||
clean_ax(kwargs["ax"], color, lw)
|
||||
return res
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_img(img, mask=None, ax=None, norm=False):
|
||||
if norm:
|
||||
vmin, vmax = np.min(img), np.max(img)
|
||||
img = (img - vmin) / (vmax - vmin)
|
||||
if mask is not None:
|
||||
img = img * mask + np.ones_like(img) * (1.0 - mask)
|
||||
|
||||
return ax.imshow(img.clip(0.0, 1.0), interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_mask(m, ax):
|
||||
color_conv = get_mask_plot_colors(m.shape[0])
|
||||
color_mask = np.dot(np.transpose(m, [1, 2, 0]), color_conv)
|
||||
return ax.imshow(color_mask.clip(0.0, 1.0), interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"):
|
||||
return ax.matshow(
|
||||
m[Ellipsis, 0], cmap=cmap, vmin=vmin, vmax=vmax, interpolation="nearest")
|
||||
|
||||
|
||||
@optional_clean_ax
|
||||
def show_coords(m, ax):
|
||||
vmin, vmax = np.min(m), np.max(m)
|
||||
m = (m - vmin) / (vmax - vmin)
|
||||
color_conv = get_mask_plot_colors(m.shape[-1])
|
||||
color_mask = np.dot(m, color_conv)
|
||||
return ax.imshow(color_mask, interpolation="nearest")
|
||||
|
||||
|
||||
def example_plot(rinfo,
|
||||
b=0,
|
||||
t=-1,
|
||||
mask_components=False,
|
||||
size=2,
|
||||
column_titles=True):
|
||||
image = rinfo["data"]["image"][b, 0]
|
||||
recons = rinfo["outputs"]["recons"][b, t, 0]
|
||||
pred_mask = rinfo["outputs"]["pred_mask"][b, t]
|
||||
components = rinfo["outputs"]["components"][b, t]
|
||||
|
||||
K, H, W, C = components.shape
|
||||
colors = get_mask_plot_colors(K)
|
||||
|
||||
nrows = 1
|
||||
ncols = 3 + K
|
||||
fig, axes = plt.subplots(ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||
|
||||
show_img(image, ax=axes[0], color="#000000")
|
||||
show_img(recons, ax=axes[1], color="#000000")
|
||||
show_mask(pred_mask[Ellipsis, 0], ax=axes[2], color="#000000")
|
||||
for k in range(K):
|
||||
mask = pred_mask[k] if mask_components else None
|
||||
show_img(components[k], ax=axes[k + 3], color=colors[k], mask=mask)
|
||||
|
||||
if column_titles:
|
||||
labels = ["Image", "Recons.", "Mask"
|
||||
] + ["Component {}".format(k + 1) for k in range(K)]
|
||||
for ax, title in zip(axes, labels):
|
||||
ax.set_title(title)
|
||||
plt.subplots_adjust(hspace=0.03, wspace=0.035)
|
||||
return fig
|
||||
|
||||
|
||||
def iterations_plot(rinfo, b=0, mask_components=False, size=2):
|
||||
image = rinfo["data"]["image"][b]
|
||||
true_mask = rinfo["data"]["true_mask"][b]
|
||||
recons = rinfo["outputs"]["recons"][b]
|
||||
pred_mask = rinfo["outputs"]["pred_mask"][b]
|
||||
pred_mask_logits = rinfo["outputs"]["pred_mask_logits"][b]
|
||||
components = rinfo["outputs"]["components"][b]
|
||||
|
||||
T, K, H, W, C = components.shape
|
||||
colors = get_mask_plot_colors(K)
|
||||
nrows = T + 1
|
||||
ncols = 2 + K
|
||||
fig, axes = plt.subplots(
|
||||
nrows=nrows, ncols=ncols, figsize=(ncols * size, nrows * size))
|
||||
for t in range(T):
|
||||
show_img(recons[t, 0], ax=axes[t, 0])
|
||||
show_mask(pred_mask[t, Ellipsis, 0], ax=axes[t, 1])
|
||||
axes[t, 0].set_ylabel("iter {}".format(t))
|
||||
for k in range(K):
|
||||
mask = pred_mask[t, k] if mask_components else None
|
||||
show_img(components[t, k], ax=axes[t, k + 2], color=colors[k], mask=mask)
|
||||
|
||||
axes[0, 0].set_title("Reconstruction")
|
||||
axes[0, 1].set_title("Mask")
|
||||
show_img(image[0], ax=axes[T, 0])
|
||||
show_mask(true_mask[0, Ellipsis, 0], ax=axes[T, 1])
|
||||
vmin = np.min(pred_mask_logits[T - 1])
|
||||
vmax = np.max(pred_mask_logits[T - 1])
|
||||
|
||||
for k in range(K):
|
||||
axes[0, k + 2].set_title("Component {}".format(k + 1)) # , color=colors[k])
|
||||
show_mat(
|
||||
pred_mask_logits[T - 1, k], ax=axes[T, k + 2], vmin=vmin, vmax=vmax)
|
||||
axes[T, k + 2].set_xlabel(
|
||||
"Mask Logits for\nComponent {}".format(k + 1)) # , color=colors[k])
|
||||
axes[T, 0].set_xlabel("Input Image")
|
||||
axes[T, 1].set_xlabel("Ground Truth Mask")
|
||||
|
||||
plt.subplots_adjust(wspace=0.05, hspace=0.05)
|
||||
return fig
|
||||
|
||||
|
||||
def inputs_plot(rinfo, b=0, t=0, size=2):
|
||||
B, T, K, H, W, C = rinfo["outputs"]["components"].shape
|
||||
colors = get_mask_plot_colors(K)
|
||||
inputs = rinfo["inputs"]["spatial"]
|
||||
rows = [
|
||||
("image", show_img, False),
|
||||
("components", show_img, False),
|
||||
("dcomponents", functools.partial(show_img, norm=True), False),
|
||||
("mask", show_mat, True),
|
||||
("pred_mask", show_mat, True),
|
||||
("dmask", functools.partial(show_mat, cmap="coolwarm"), True),
|
||||
("posterior", show_mat, True),
|
||||
("log_prob", show_mat, True),
|
||||
("counterfactual", show_mat, True),
|
||||
("coordinates", show_coords, False),
|
||||
]
|
||||
rows = [(n, f, mcb) for n, f, mcb in rows if n in inputs]
|
||||
nrows = len(rows)
|
||||
ncols = K + 1
|
||||
|
||||
fig, axes = plt.subplots(
|
||||
nrows=nrows,
|
||||
ncols=ncols,
|
||||
figsize=(ncols * size - size * 0.9, nrows * size),
|
||||
gridspec_kw={"width_ratios": [1] * K + [0.1]},
|
||||
)
|
||||
for r, (name, plot_fn, make_cbar) in enumerate(rows):
|
||||
axes[r, 0].set_ylabel(name)
|
||||
if make_cbar:
|
||||
vmin = np.min(inputs[name][b, t])
|
||||
vmax = np.max(inputs[name][b, t])
|
||||
if np.abs(vmin - vmax) < 1e-6:
|
||||
vmin -= 0.1
|
||||
vmax += 0.1
|
||||
plot_fn = functools.partial(plot_fn, vmin=vmin, vmax=vmax)
|
||||
# print("range of {:<16}: [{:0.2f}, {:0.2f}]".format(name, vmin, vmax))
|
||||
for k in range(K):
|
||||
if inputs[name].shape[2] == 1:
|
||||
m = inputs[name][b, t, 0]
|
||||
color = (0.0, 0.0, 0.0)
|
||||
else:
|
||||
m = inputs[name][b, t, k]
|
||||
color = colors[k]
|
||||
mappable = plot_fn(m, ax=axes[r, k], color=color)
|
||||
if make_cbar:
|
||||
fig.colorbar(mappable, cax=axes[r, K])
|
||||
else:
|
||||
axes[r, K].set_visible(False)
|
||||
for k in range(K):
|
||||
axes[0, k].set_title("Component {}".format(k + 1)) # , color=colors[k])
|
||||
|
||||
plt.subplots_adjust(hspace=0.05, wspace=0.05)
|
||||
return fig
|
||||
@@ -0,0 +1,163 @@
|
||||
# Lint as: python3
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Iterative refinement modules."""
|
||||
# pylint: disable=g-doc-bad-indent, unused-variable
|
||||
from iodine.modules import utils
|
||||
import shapeguard
|
||||
import sonnet as snt
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
|
||||
class RefinementCore(snt.RNNCore):
|
||||
"""Recurrent Refinement Module.
|
||||
|
||||
Refinement modules take as inputs:
|
||||
* previous state (which could be an arbitrary nested structure)
|
||||
* current inputs which include
|
||||
* image-space inputs like pixel-based errors, or mask-posteriors
|
||||
* latent-space inputs like the previous z_dist, or dz
|
||||
|
||||
They use these inputs to produce:
|
||||
* output (usually a new z_dist)
|
||||
* new_state
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
encoder_net,
|
||||
recurrent_net,
|
||||
refinement_head,
|
||||
name="refinement"):
|
||||
super().__init__(name=name)
|
||||
self._encoder_net = encoder_net
|
||||
self._recurrent_net = recurrent_net
|
||||
self._refinement_head = refinement_head
|
||||
self._sg = shapeguard.ShapeGuard()
|
||||
|
||||
def initial_state(self, batch_size, **unused_kwargs):
|
||||
return self._recurrent_net.initial_state(batch_size)
|
||||
|
||||
def _build(self, inputs, prev_state):
|
||||
sg = self._sg
|
||||
assert "spatial" in inputs, inputs.keys()
|
||||
assert "flat" in inputs, inputs.keys()
|
||||
assert "zp" in inputs["flat"], inputs["flat"].keys()
|
||||
zp = sg.guard(inputs["flat"]["zp"], "B, K, Zp")
|
||||
|
||||
x = sg.guard(self.prepare_spatial_inputs(inputs["spatial"]), "B*K, H, W, C")
|
||||
h1 = sg.guard(self._encoder_net(x).params, "B*K, H1")
|
||||
h2 = sg.guard(self.prepare_flat_inputs(h1, inputs["flat"]), "B*K, H2")
|
||||
h2_unflattened = sg.reshape(h2, "B, K, H2")
|
||||
h3, next_state = self._recurrent_net(h2_unflattened, prev_state)
|
||||
sg.guard(h3, "B, K, H3")
|
||||
outputs = sg.guard(self._refinement_head(zp, h3), "B, K, Y")
|
||||
|
||||
del self._sg.B
|
||||
return outputs, next_state
|
||||
|
||||
def prepare_spatial_inputs(self, inputs):
|
||||
values = []
|
||||
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
|
||||
if val.shape.as_list()[1] == 1:
|
||||
self._sg.guard(val, "B, 1, H, W, _C")
|
||||
val = tf.tile(val, self._sg["1, K, 1, 1, 1"])
|
||||
else:
|
||||
self._sg.guard(val, "B, K, H, W, _C")
|
||||
values.append(val)
|
||||
concat_inputs = self._sg.guard(tf.concat(values, axis=-1), "B, K, H, W, C")
|
||||
return self._sg.reshape(concat_inputs, "B*K, H, W, C")
|
||||
|
||||
def prepare_flat_inputs(self, hidden, inputs):
|
||||
values = [self._sg.guard(hidden, "B*K, H1")]
|
||||
|
||||
for name, val in sorted(inputs.items(), key=lambda it: it[0]):
|
||||
self._sg.guard(val, "B, K, _")
|
||||
val_flat = tf.reshape(val, self._sg["B*K"] + [-1])
|
||||
values.append(val_flat)
|
||||
return tf.concat(values, axis=-1)
|
||||
|
||||
|
||||
class ResHead(snt.AbstractModule):
|
||||
"""Updates Zp using a residual mechanism."""
|
||||
|
||||
def __init__(self, name="residual_head"):
|
||||
super().__init__(name=name)
|
||||
|
||||
def _build(self, zp_old, inputs):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(zp_old, "B, K, Zp")
|
||||
sg.guard(inputs, "B, K, H")
|
||||
update = snt.Linear(sg.Zp)
|
||||
|
||||
flat_zp = sg.reshape(zp_old, "B*K, Zp")
|
||||
flat_inputs = sg.reshape(inputs, "B*K, H")
|
||||
|
||||
zp = flat_zp + update(flat_inputs)
|
||||
|
||||
return sg.reshape(zp, "B, K, Zp")
|
||||
|
||||
|
||||
class PredictorCorrectorHead(snt.AbstractModule):
|
||||
"""This refinement head is used for sequential data.
|
||||
|
||||
At every step it computes a prediction from the λ of the previous timestep
|
||||
and an update from the refinement network of the current timestep.
|
||||
|
||||
The next step λ' is computed as a gated combination of both:
|
||||
λ' = g * λ_corr + (1-g) * λ_pred
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_sizes=(64,),
|
||||
pred_gate_bias=0.0,
|
||||
corrector_gate_bias=0.0,
|
||||
activation=tf.nn.elu,
|
||||
name="predcorr_head",
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._hidden_sizes = hidden_sizes
|
||||
self._activation = utils.get_act_func(activation)
|
||||
self._pred_gate_bias = pred_gate_bias
|
||||
self._corrector_gate_bias = corrector_gate_bias
|
||||
|
||||
def _build(self, zp_old, inputs):
|
||||
sg = shapeguard.ShapeGuard()
|
||||
sg.guard(zp_old, "B, K, Zp")
|
||||
sg.guard(inputs, "B, K, H")
|
||||
update = snt.Linear(sg.Zp)
|
||||
update_gate = snt.Linear(sg.Zp)
|
||||
predict = snt.nets.MLP(
|
||||
output_sizes=list(self._hidden_sizes) + [sg.Zp * 2],
|
||||
activation=self._activation,
|
||||
)
|
||||
|
||||
flat_zp = sg.reshape(zp_old, "B*K, Zp")
|
||||
flat_inputs = sg.reshape(inputs, "B*K, H")
|
||||
|
||||
g = tf.nn.sigmoid(update_gate(flat_inputs) + self._corrector_gate_bias)
|
||||
u = update(flat_inputs)
|
||||
|
||||
# a slightly more efficient way of computing the gated update
|
||||
# (1-g) * flat_zp + g * u
|
||||
zp_corrected = flat_zp + g * (u - flat_zp)
|
||||
|
||||
predicted = predict(flat_zp)
|
||||
pred_up = predicted[:, :sg.Zp]
|
||||
pred_gate = tf.nn.sigmoid(predicted[:, sg.Zp:] + self._pred_gate_bias)
|
||||
|
||||
zp = zp_corrected + pred_gate * (pred_up - zp_corrected)
|
||||
|
||||
return sg.reshape(zp, "B, K, Zp")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,9 @@
|
||||
tensorflow-gpu==1.14.0
|
||||
tensorflow-probability==0.7.0
|
||||
dm-sonnet==1.35
|
||||
sacred>=0.7,<0.8
|
||||
shapeguard
|
||||
seaborn
|
||||
pymongo
|
||||
jupyterlab
|
||||
git+git://github.com/deepmind/multi_object_datasets.git
|
||||
Executable
+36
@@ -0,0 +1,36 @@
|
||||
#!/bin/sh
|
||||
# Copyright 2019 Deepmind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
set -e
|
||||
|
||||
echo "downloading checkpoints from GCP"
|
||||
iodine/download_checkpoints.sh
|
||||
|
||||
python3 -m venv iodine_venv
|
||||
source iodine_venv/bin/activate
|
||||
pip3 install --upgrade setuptools wheel
|
||||
pip3 install -r iodine/requirements.txt
|
||||
|
||||
# Get some fake data and put it where the real multi_objects_dataset files live.
|
||||
mkdir -p iodine/multi_object_datasets
|
||||
cp iodine/test_data/tetrominoes_mini.tfrecords iodine/multi_object_datasets/tetrominoes_train.tfrecords
|
||||
|
||||
# Run training with a cut down size.
|
||||
python3 -m iodine.main \
|
||||
-f with tetrominoes \
|
||||
data.shuffle_buffer=2 \
|
||||
data.batch_size=2 \
|
||||
n_z=4 \
|
||||
num_components=3 \
|
||||
stop_after_steps=11
|
||||
Binary file not shown.
Reference in New Issue
Block a user