Files
deepmind-research/ode_gan/odegan_mog16.ipynb
T
Chongli Qin 555df6b45e Add citation in README for ode-gan.
PiperOrigin-RevId: 343842179
2020-11-25 10:16:14 +00:00

354 lines
14 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "OYWMcJafmrfI"
},
"source": [
"Copyright 2020 DeepMind Technologies Limited.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
"\n",
"https://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yAHjf0hcm8Az"
},
"source": [
"# **This notebook implements ODE-GAN on a 2-D mixture of 16 Gaussians.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "n8p0WAstrhUT"
},
"outputs": [],
"source": [
"#@title Imports\n",
"!pip install dm-haiku\n",
"import jax\n",
"from jax import lax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import haiku as hk\n",
"import scipy as sp\n",
"import functools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "aoIaRyCysZEs"
},
"outputs": [],
"source": [
"#@title An MLP Haiku Module \n",
"\n",
"class MLP(hk.Module):\n",
"  \"\"\"Plain ReLU multi-layer perceptron.\n",
"\n",
"  Applies `depth` hidden hk.Linear layers of width `hidden_size`, each\n",
"  followed by a ReLU, then a final linear projection to `out_dim` outputs\n",
"  with no output activation.\n",
"  \"\"\"\n",
"\n",
"  def __init__(self, depth, hidden_size, out_dim, name='SimpleNet'):\n",
"    super().__init__(name=name)\n",
"    self._depth = depth\n",
"    self._hidden_size = hidden_size\n",
"    self._out_dim = out_dim\n",
"    # Explicit, stable layer names: 'linear_0', 'linear_1', ..., 'final_layer'.\n",
"    self._layers = [\n",
"        hk.Linear(hidden_size, name='linear_%d' % idx)\n",
"        for idx in range(depth)\n",
"    ]\n",
"    self._final_layer = hk.Linear(out_dim, name='final_layer')\n",
"\n",
"  def __call__(self, input):\n",
"    activations = input\n",
"    for layer in self._layers:\n",
"      activations = jax.nn.relu(layer(activations))\n",
"    return self._final_layer(activations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "KBgWwKKyv6VI"
},
"outputs": [],
"source": [
"#@title Real Data\n",
"def real_data(batch_size):\n",
" mog_mean = np.array([\n",
" [ 1.50, 1.50],\n",
" [ 1.50, 0.50],\n",
" [ 1.50, -0.50],\n",
" [ 1.50, -1.50],\n",
" [ 0.50, 1.50],\n",
" [ 0.50, 0.50],\n",
" [ 0.50, -0.50],\n",
" [ 0.50, -1.50],\n",
" [-1.50, 1.50],\n",
" [-1.50, 0.50],\n",
" [-1.50, -0.50],\n",
" [-1.50, -1.50],\n",
" [-0.50, 1.50],\n",
" [-0.50, 0.50],\n",
" [-0.50, -0.50],\n",
" [-0.50, -1.50],\n",
" ])\n",
" temp = np.tile(mog_mean, (batch_size // 16 + 1,1))\n",
" mus = temp[0:batch_size,:]\n",
" return mus + 0.02 * np.random.normal(size=(batch_size, 2))\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "E6uViIllRDlL"
},
"outputs": [],
"source": [
"#@title ODE-integrators\n",
"def euler_step(func, y0, f0, t0, dt):\n",
" # Euler update\n",
" y1 = jax.tree_multimap(lambda u, v: dt * v + u, y0, f0)\n",
" return y1\n",
"\n",
"def euler_heun_step(func, y0, f0, t0, dt):\n",
" # RK2 Butcher tableaux\n",
" alpha = jnp.array([1. / 2., 0.])\n",
" beta = jnp.array([\n",
" [1 / 2, 0,],\n",
" ])\n",
" c_sol = jnp.array([1 / 2, 1 / 2])\n",
"\n",
" def body_fun(i, k):\n",
" ti = t0 + dt * alpha[i-1]\n",
" yi = jax.tree_multimap(lambda u, v: u + dt * jnp.tensordot(beta[i-1, :], v, axes=1), y0, k)\n",
" ft = func(yi, ti)\n",
" return jax.tree_multimap(lambda x, y: x.at[i, :].set(y), k, ft)\n",
" k = jax.tree_map(lambda f: jnp.zeros((2,) + f.shape,\n",
" f.dtype).at[0, :].set(f), f0)\n",
" k = lax.fori_loop(1, 2, body_fun, k)\n",
"\n",
" y1 = jax.tree_multimap(lambda u, v: dt * jnp.tensordot(c_sol, v, axes=1) + u, y0, k)\n",
" return y1\n",
"\n",
"def runge_kutta_step(func, y0, f0, t0, dt):\n",
" # RK4 Butcher tableaux\n",
" alpha = jnp.array([1. / 2., 1. / 2., 1., 0])\n",
" beta = jnp.array([\n",
" [1. / 2., 0, 0, 0],\n",
" [0, 1. / 2., 0, 0],\n",
" [0, 0, 1., 0],\n",
" ])\n",
" c_sol = jnp.array([1. / 6., 1. / 3., 1. / 3., 1. / 6.])\n",
"\n",
" def body_fun(i, k):\n",
" ti = t0 + dt * alpha[i-1]\n",
" yi = jax.tree_multimap(lambda u, v: u + dt * jnp.tensordot(beta[i-1, :], v, axes=1), y0, k)\n",
" ft = func(yi, ti)\n",
" return jax.tree_multimap(lambda x, y: x.at[i, :].set(y), k, ft)\n",
" k = jax.tree_map(lambda f: jnp.zeros((4,) + f.shape,\n",
" f.dtype).at[0, :].set(f), f0)\n",
" k = lax.fori_loop(1, 4, body_fun, k)\n",
"\n",
" y1 = jax.tree_multimap(lambda u, v: dt * jnp.tensordot(c_sol, v, axes=1) + u, y0, k)\n",
" return y1"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "NHCYH1tnwaTL"
},
"outputs": [],
"source": [
"#@title Utility Functions.\n",
"def disc_loss(disc_params, gen_params, real_examples, latents):\n",
"  \"\"\"Discriminator objective.\n",
"\n",
"  With s = sigmoid(logit) and the identity\n",
"  x - log_sigmoid(x) = log(1 + exp(x)) = -log(1 - sigmoid(x)), the value is\n",
"  mean(log(1 - s(real)) + log(s(fake))). The integrators move parameters\n",
"  along +gradient of this value, so s effectively scores 'generated'.\n",
"  \"\"\"\n",
"  generated = gen_model.apply(gen_params, None, latents)\n",
"  logits_real = disc_model.apply(disc_params, None, real_examples)\n",
"  logits_fake = disc_model.apply(disc_params, None, generated)\n",
"  real_term = logits_real - jax.nn.log_sigmoid(logits_real)\n",
"  fake_term = -jax.nn.log_sigmoid(logits_fake)\n",
"  return -jnp.mean(real_term + fake_term)\n",
"\n",
"def gen_loss(disc_params, gen_params, real_examples, latents):\n",
"  \"\"\"Generator objective: mean(log(1 - sigmoid(logit))) on generated samples.\n",
"\n",
"  Uses x - log_sigmoid(x) = -log(1 - sigmoid(x)) on the discriminator logits\n",
"  of G(latents). `real_examples` is unused; it is accepted so all loss\n",
"  functions share one signature.\n",
"  \"\"\"\n",
"  logits = disc_model.apply(\n",
"      disc_params, None, gen_model.apply(gen_params, None, latents))\n",
"  softplus_logits = logits - jax.nn.log_sigmoid(logits)\n",
"  return -jnp.mean(softplus_logits)\n",
"\n",
"def gen_norm(disc_params, gen_params, real_examples, latents):\n",
"  \"\"\"Returns minus the squared L2 norm of the generator's gradient.\n",
"\n",
"  The gradient is d gen_loss / d gen_params, summed over all parameter\n",
"  leaves. ode_update differentiates this w.r.t. the discriminator parameters\n",
"  to build the regulariser added to the discriminator step.\n",
"  \"\"\"\n",
"  grad = jax.grad(gen_loss, argnums=1)(\n",
"      disc_params, gen_params, real_examples, latents)\n",
"  # jax.tree_util.tree_leaves replaces the top-level jax.tree_flatten alias,\n",
"  # which is deprecated and removed in recent JAX releases.\n",
"  leaves = jax.tree_util.tree_leaves(grad)\n",
"  return -sum((jnp.sum(a * a) for a in leaves), 0.)\n",
"\n",
"def get_gen_grad(gen_params, t, disc_params, real_examples, latents):\n",
"  \"\"\"Generator vector field func(y, t) = d gen_loss / d gen_params.\n",
"\n",
"  `t` is ignored (the field does not depend on time); it exists so this can be\n",
"  handed to the ODE integrators, which call func(y, t).\n",
"  \"\"\"\n",
"  del t\n",
"  generator_grad = jax.grad(gen_loss, argnums=1)\n",
"  return generator_grad(disc_params, gen_params, real_examples, latents)\n",
"\n",
"def get_disc_grad(disc_params, t, gen_params, real_examples, latents):\n",
"  \"\"\"Discriminator vector field func(y, t) = d disc_loss / d disc_params.\n",
"\n",
"  `t` is ignored (the field does not depend on time); it exists so this can be\n",
"  handed to the ODE integrators, which call func(y, t).\n",
"  \"\"\"\n",
"  del t\n",
"  discriminator_grad = jax.grad(disc_loss, argnums=0)\n",
"  return discriminator_grad(disc_params, gen_params, real_examples, latents)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "xjTBhJuOh_wO"
},
"outputs": [],
"source": [
"#@title Visualising the data.\n",
"\n",
"def kde(mu, tau, contours=None, bbox=None, xlabel=\"\", ylabel=\"\", cmap='Blues', st=0):\n",
"  \"\"\"Plots a Gaussian KDE of 2-D samples, optionally overlaying contours.\n",
"\n",
"  Args:\n",
"    mu: x-coordinates of the samples, shape (n,).\n",
"    tau: y-coordinates of the samples, shape (n,).\n",
"    contours: optional values on the grid np.meshgrid(g, g) with\n",
"      g = np.arange(-2., 2., 0.1), flattened in that meshgrid order; drawn as\n",
"      20 black contour levels.\n",
"    bbox: [xmin, xmax, ymin, ymax] plot window. Defaults to the sample extent.\n",
"    xlabel, ylabel: axis labels.\n",
"    cmap: matplotlib colormap name for the filled density.\n",
"    st: unused; kept for backward compatibility with existing callers.\n",
"  \"\"\"\n",
"  # Import the submodule explicitly: plain `import scipy` does not load\n",
"  # scipy.stats on older SciPy versions.\n",
"  from scipy import stats\n",
"\n",
"  if bbox is None:\n",
"    bbox = [np.min(mu), np.max(mu), np.min(tau), np.max(tau)]\n",
"  kernel = stats.gaussian_kde(np.vstack([mu, tau]))\n",
"\n",
"  fig, ax = plt.subplots()\n",
"  ax.axis(bbox)\n",
"  ax.set_aspect(abs(bbox[1] - bbox[0]) / abs(bbox[3] - bbox[2]))\n",
"  ax.set_xlabel(xlabel)\n",
"  ax.set_ylabel(ylabel)\n",
"  ax.set_xticks([])\n",
"  ax.set_yticks([])\n",
"\n",
"  # Evaluate the density on a 300x300 grid covering the plot window.\n",
"  xx, yy = np.mgrid[bbox[0]:bbox[1]:300j, bbox[2]:bbox[3]:300j]\n",
"  positions = np.vstack([xx.ravel(), yy.ravel()])\n",
"  density = np.reshape(kernel(positions).T, xx.shape)\n",
"  ax.contourf(xx, yy, density, cmap=cmap)\n",
"\n",
"  if contours is not None:\n",
"    # Must match the grid on which the caller evaluated `contours`.\n",
"    grid = np.arange(-2., 2., 0.1)\n",
"    cx, cy = np.meshgrid(grid, grid)\n",
"    ax.contour(cx, cy, np.asarray(contours).squeeze().reshape(cx.shape),\n",
"               levels=20, colors='k', linewidths=0.8, alpha=0.5)\n",
"  fig.tight_layout()\n",
"  plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "G2G32j5N1psa"
},
"outputs": [],
"source": [
"#@title Integration\n",
"n_itrs = 30001 #@param {type : 'integer'}\n",
"n_save = 2000 #@param {type : 'integer'}\n",
"latent_size = 32 #@param {type : 'integer'}\n",
"bs = 512 #@param {type : 'integer'}\n",
"odeint = 'runge_kutta_step' #@param ['euler_step', 'euler_heun_step', 'runge_kutta_step'] {type : 'string'}\n",
"delta_t = 0.10 #@param {type : 'number'}\n",
"reg_param = 0.07 #@param {type : 'number'}\n",
"\n",
"def forward_disc(batch):\n",
"  \"\"\"Discriminator forward pass: 2 hidden ReLU layers of width 25, 1 logit.\"\"\"\n",
"  return MLP(2, 25, 1)(batch)\n",
"\n",
"def forward_gen(batch):\n",
"  \"\"\"Generator forward pass: 2 hidden ReLU layers of width 25, 2-D output.\"\"\"\n",
"  return MLP(2, 25, 2)(batch)\n",
"\n",
"\n",
"# hk.transform turns each forward function into an (init, apply) pair.\n",
"disc_model, gen_model = map(hk.transform, (forward_disc, forward_gen))\n",
"# Only used to initialise the discriminator below; the training loop draws\n",
"# fresh real batches.\n",
"real_examples = real_data(bs)\n",
"\n",
"# Integrators selectable through the `odeint` form field.\n",
"ODEINT = dict(\n",
"    runge_kutta_step=runge_kutta_step,\n",
"    euler_heun_step=euler_heun_step,\n",
"    euler_step=euler_step,\n",
")\n",
"@jax.jit\n",
"def ode_update(i, disc_params, gen_params, real_examples, latents):\n",
"  \"\"\"Advances discriminator and generator parameters by one ODE step.\n",
"\n",
"  Each player's parameters follow the gradient of its own objective\n",
"  (disc_loss / gen_loss), integrated over a step of size delta_t with the\n",
"  integrator ODEINT[odeint]. The discriminator update additionally gets\n",
"  delta_t * reg_param times the gradient of gen_norm w.r.t. the\n",
"  discriminator parameters.\n",
"\n",
"  NOTE: odeint, delta_t and reg_param are globals read when jit traces this\n",
"  function, so changing them requires re-running this cell.\n",
"\n",
"  Args:\n",
"    i: iteration index (unused).\n",
"    disc_params: discriminator Haiku parameters.\n",
"    gen_params: generator Haiku parameters.\n",
"    real_examples: batch of real data.\n",
"    latents: batch of generator inputs.\n",
"\n",
"  Returns:\n",
"    (new_disc_params, new_gen_params, -disc_loss, -gen_loss), with both losses\n",
"    evaluated at the input parameters.\n",
"  \"\"\"\n",
"  del i  # Unused.\n",
"  dloss, disc_grad = jax.value_and_grad(disc_loss, argnums=0)(\n",
"      disc_params, gen_params, real_examples, latents)\n",
"  gloss, gen_grad = jax.value_and_grad(gen_loss, argnums=1)(\n",
"      disc_params, gen_params, real_examples, latents)\n",
"  disc_gen_grad = jax.grad(gen_norm, argnums=0)(\n",
"      disc_params, gen_params, real_examples, latents)\n",
"\n",
"  # Vector fields func(y, t) for the integrators, with the other player's\n",
"  # parameters frozen at their values from the start of the step.\n",
"  grad_disc_fn = functools.partial(\n",
"      get_disc_grad, gen_params=gen_params,\n",
"      real_examples=real_examples, latents=latents)\n",
"  grad_gen_fn = functools.partial(\n",
"      get_gen_grad, disc_params=disc_params,\n",
"      real_examples=real_examples, latents=latents)\n",
"\n",
"  integrator = ODEINT[odeint]\n",
"  new_gen_params = integrator(grad_gen_fn, gen_params, gen_grad, 0., delta_t)\n",
"  new_disc_params = integrator(\n",
"      grad_disc_fn, disc_params, disc_grad, 0., delta_t)\n",
"  # jax.tree_util.tree_map replaces jax.tree_multimap, which recent JAX\n",
"  # releases no longer provide.\n",
"  new_disc_params = jax.tree_util.tree_map(\n",
"      lambda x, y: x + delta_t * reg_param * y, new_disc_params, disc_gen_grad)\n",
"  return new_disc_params, new_gen_params, -dloss, -gloss\n",
"\n",
"rng = jax.random.PRNGKey(np.random.randint(low=0, high=int(1e7)))\n",
"test_latents = np.random.normal(size=(bs * 10, latent_size))\n",
"latents = np.random.normal(size=(bs, latent_size))\n",
"disc_params = disc_model.init(rng, real_examples)\n",
"gen_params = gen_model.init(\n",
"    jax.random.PRNGKey(np.random.randint(low=0, high=int(1e7))), latents)\n",
"\n",
"# Grid on which the discriminator is evaluated for the contour overlay.\n",
"# kde() rebuilds the same np.arange(-2., 2., 0.1) grid, so keep them in sync.\n",
"x = np.arange(-2., 2., 0.1)\n",
"y = np.arange(-2., 2., 0.1)\n",
"X, Y = np.meshgrid(x, y)\n",
"pairs = np.reshape(np.stack((X, Y), axis=-1), (-1, 2))\n",
"bbox = [-2, 2, -2, 2]\n",
"\n",
"for e in range(n_itrs):\n",
"  real_examples = real_data(bs)\n",
"  latents = np.random.normal(size=(bs, latent_size))\n",
"\n",
"  (disc_params, gen_params,\n",
"   dloss, gloss) = ode_update(e, disc_params, gen_params, real_examples, latents)\n",
"\n",
"  if e % n_save == 0:\n",
"    grid_logits = disc_model.apply(disc_params, None, pairs)\n",
"    # -logit + log_sigmoid(logit) = log(1 - sigmoid(logit)).\n",
"    disc_contour = - grid_logits + jax.nn.log_sigmoid(grid_logits)\n",
"    print('i = %d, discriminant loss = %s, generator loss = %s' %\n",
"          (e, dloss, gloss))\n",
"    fake_examples = gen_model.apply(gen_params, None, test_latents)\n",
"    kde(fake_examples[:, 0], fake_examples[:, 1], contours=disc_contour,\n",
"        bbox=bbox, st=e)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "odegan_mog16.ipynb"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}