{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V7L7Byz8qQZd"
      },
      "source": [
        "Copyright 2021 The Powerpropagation Authors. All rights reserved.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "SuSveOpsxQuA"
      },
      "outputs": [],
      "source": [
        "#@title Imports \u0026 Dependencies\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import seaborn as sns\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_probability as tfp\n",
        "\n",
        "# Install and import Sonnet (must happen before `import sonnet`).\n",
        "print('Installing necessary libraries...')\n",
        "\n",
        "def install_libraries():\n",
        "  !pip install dm-sonnet\n",
        "\n",
        "import IPython\n",
        "\n",
        "# Suppress the (lengthy) pip output.\n",
        "with IPython.utils.io.capture_output() as captured:\n",
        "  install_libraries()\n",
        "\n",
        "import sonnet as snt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zNZE9eU_poe0"
      },
      "source": [
        "### Implementation"
      ]
    },
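    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The modules below reparameterise every weight matrix as $\\theta = \\mathrm{sign}(w)\\,|w|^{\\alpha} = w\\,|w|^{\\alpha - 1}$ and train the underlying parameters $w$ by plain SGD. By the chain rule, the gradient with respect to $w$ picks up a factor of $\\alpha\\,|w|^{\\alpha - 1}$, so for $\\alpha \u003e 1$ weights that are already small receive smaller updates, which in effect concentrates magnitude in fewer weights and makes the network more robust to magnitude pruning; $\\alpha = 1$ recovers a standard network."
      ]
    },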
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pQBoLdNFL7fE"
      },
      "outputs": [],
      "source": [
        "#@title Training \u0026 Pruning functions\n",
        "\n",
        "@tf.function\n",
        "def train_fn(model, train_x, train_y, optimizer_to_use):\n",
        "  \"\"\"Runs a single optimisation step and returns the training statistics.\"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    loss, stats = model.loss(train_x, train_y)\n",
        "\n",
        "  train_vars = model.trainable_variables()\n",
        "  train_grads = tape.gradient(loss, train_vars)\n",
        "  optimizer_to_use.apply(train_grads, train_vars)\n",
        "\n",
        "  return stats\n",
        "\n",
        "@tf.function\n",
        "def eval_fn(model, eval_x, eval_y):\n",
        "  _, stats = model.loss(eval_x, eval_y)\n",
        "\n",
        "  return stats\n",
        "\n",
        "def _bottom_k_mask(percent_to_keep, condition):\n",
        "  \"\"\"Returns a 0/1 mask keeping the `percent_to_keep` largest entries.\"\"\"\n",
        "  how_many = int(percent_to_keep * condition.size)\n",
        "  top_k = tf.nn.top_k(condition, k=how_many)\n",
        "\n",
        "  mask = np.zeros(shape=condition.shape, dtype=np.float32)\n",
        "  mask[top_k.indices.numpy()] = 1\n",
        "\n",
        "  assert np.sum(mask) == how_many\n",
        "\n",
        "  return mask\n",
        "\n",
        "def prune_by_magnitude(percent_to_keep, weight):\n",
        "  \"\"\"Masks out all but the `percent_to_keep` largest-magnitude weights.\"\"\"\n",
        "  mask = _bottom_k_mask(percent_to_keep, np.abs(weight.flatten()))\n",
        "\n",
        "  return mask.reshape(weight.shape)"
      ]
    },
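    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal illustrative example of `prune_by_magnitude` (the toy matrix and the 50% keep-fraction are arbitrary choices, not part of the experiment below): the returned mask keeps the largest-magnitude entries and zeroes the rest."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#@title Example: magnitude pruning on a toy matrix\n",
        "\n",
        "# Illustrative sketch only: keep the 50% largest-magnitude entries.\n",
        "toy_w = np.array([[1.0, -3.0], [0.2, 4.0]], dtype=np.float32)\n",
        "toy_mask = prune_by_magnitude(0.5, toy_w)\n",
        "print(toy_mask)          # 1.0 marks entries that survive pruning\n",
        "print(toy_w * toy_mask)  # -3.0 and 4.0 survive; 1.0 and 0.2 are pruned"
      ]
    },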
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uAPa99KakJDB"
      },
      "outputs": [],
      "source": [
        "#@title Initialisers\n",
        "\n",
        "class PowerPropVarianceScaling(snt.initializers.VarianceScaling):\n",
        "  \"\"\"Variance-scaling initialiser composed with the inverse reparameterisation.\n",
        "\n",
        "  Drawing u from the base distribution and returning sign(u) * |u|^(1/alpha)\n",
        "  means the effective weights sign(w) * |w|^alpha start out distributed\n",
        "  according to the base distribution itself.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, alpha, *args, **kwargs):\n",
        "    super(PowerPropVarianceScaling, self).__init__(*args, **kwargs)\n",
        "    self._alpha = alpha\n",
        "\n",
        "  def __call__(self, shape, dtype):\n",
        "    u = super(PowerPropVarianceScaling, self).__call__(shape, dtype).numpy()\n",
        "\n",
        "    return tf.sign(u) * tf.pow(tf.abs(u), 1.0/self._alpha)"
      ]
    },
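    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick, illustrative sanity check (the choice $\\alpha = 3$ and the $784 \\times 300$ shape are arbitrary): because the initialiser applies the *inverse* of the reparameterisation, the effective weights $\\mathrm{sign}(w)\\,|w|^{\\alpha}$ should start out with roughly the usual fan-in standard deviation, so baseline and Powerpropagation models begin from comparable conditions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#@title Sanity check: initialisation inverts the reparameterisation\n",
        "\n",
        "# Illustrative sketch only; alpha and the shape are arbitrary choices.\n",
        "_check_alpha = 3.0\n",
        "_w0 = PowerPropVarianceScaling(_check_alpha)([784, 300], tf.float32)\n",
        "_w_eff = tf.sign(_w0) * tf.pow(tf.abs(_w0), _check_alpha)\n",
        "# The effective weights should have roughly the usual fan-in std.\n",
        "print('std of effective weights:', tf.math.reduce_std(_w_eff).numpy())\n",
        "print('sqrt(1 / fan_in):        ', np.sqrt(1.0 / 784))"
      ]
    },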
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UXtgeSdPgufJ"
      },
      "outputs": [],
      "source": [
        "#@title Models\n",
        "\n",
        "class PowerPropLinear(snt.Linear):\n",
        "  \"\"\"Powerpropagation Linear module.\"\"\"\n",
        "\n",
        "  def __init__(self, alpha, *args, **kwargs):\n",
        "    super(PowerPropLinear, self).__init__(*args, **kwargs)\n",
        "    self._alpha = alpha\n",
        "\n",
        "  def get_weights(self):\n",
        "    # The effective weights of the layer: sign(w) * |w|^alpha.\n",
        "    return tf.sign(self.w) * tf.pow(tf.abs(self.w), self._alpha)\n",
        "\n",
        "  def __call__(self, inputs, mask=None):\n",
        "    self._initialize(inputs)\n",
        "    # Equivalent to sign(w) * |w|^alpha, written so that gradients\n",
        "    # flow through `self.w` rather than through `tf.sign`.\n",
        "    params = self.w * tf.pow(tf.abs(self.w), self._alpha - 1)\n",
        "\n",
        "    if mask is not None:\n",
        "      params *= mask\n",
        "\n",
        "    outputs = tf.matmul(inputs, params) + self.b\n",
        "\n",
        "    return outputs\n",
        "\n",
        "\n",
        "class MLP(snt.Module):\n",
        "  \"\"\"A multi-layer perceptron module.\"\"\"\n",
        "\n",
        "  def __init__(self, alpha, w_init, output_sizes=(300, 100, 10), name='MLP'):\n",
        "    super(MLP, self).__init__(name=name)\n",
        "    self._alpha = alpha\n",
        "    self._w_init = w_init\n",
        "\n",
        "    self._layers = []\n",
        "    for index, output_size in enumerate(output_sizes):\n",
        "      self._layers.append(\n",
        "          PowerPropLinear(\n",
        "              output_size=output_size,\n",
        "              alpha=alpha,\n",
        "              w_init=w_init,\n",
        "              name=\"linear_{}\".format(index)))\n",
        "\n",
        "  def get_weights(self):\n",
        "    return [l.get_weights().numpy() for l in self._layers]\n",
        "\n",
        "  def __call__(self, inputs, masks=None):\n",
        "    num_layers = len(self._layers)\n",
        "\n",
        "    for i, layer in enumerate(self._layers):\n",
        "      if masks is not None:\n",
        "        inputs = layer(inputs, masks[i])\n",
        "      else:\n",
        "        inputs = layer(inputs)\n",
        "      if i \u003c (num_layers - 1):\n",
        "        inputs = tf.nn.relu(inputs)\n",
        "\n",
        "    return inputs\n",
        "\n",
        "\n",
        "class DensityNetwork(snt.Module):\n",
        "  \"\"\"Wraps a network to produce a categorical distribution over classes.\"\"\"\n",
        "\n",
        "  def __init__(self, network=None, name=\"DensityNetwork\", *args, **kwargs):\n",
        "    super(DensityNetwork, self).__init__(name=name)\n",
        "    self._network = network\n",
        "\n",
        "  def __call__(self, inputs, masks=None, *args, **kwargs):\n",
        "    outputs = self._network(inputs, masks, *args, **kwargs)\n",
        "\n",
        "    return tfp.distributions.Categorical(logits=outputs), outputs\n",
        "\n",
        "  def trainable_variables(self):\n",
        "    return self._network.trainable_variables\n",
        "\n",
        "  def get_weights(self):\n",
        "    return self._network.get_weights()\n",
        "\n",
        "  def loss(self, inputs, targets, masks=None, *args, **kwargs):\n",
        "    dist, logits = self.__call__(\n",
        "        inputs, masks, *args, **kwargs)\n",
        "    loss = -tf.reduce_mean(dist.log_prob(targets))\n",
        "\n",
        "    accuracy = tf.reduce_mean(\n",
        "        tf.cast(tf.equal(tf.argmax(logits, axis=1), targets), tf.float32))\n",
        "\n",
        "    return loss, {'loss': loss, 'acc': accuracy}"
      ]
    },
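    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A minimal smoke test (the throwaway model and the random batch and labels are illustrative only): an untrained network should return a finite cross-entropy around $\\ln 10 \\approx 2.3$ and accuracy near chance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "#@title Smoke test: forward pass and loss\n",
        "\n",
        "# Illustrative sketch only; this model is not reused below.\n",
        "_demo = DensityNetwork(MLP(alpha=2.0, w_init=PowerPropVarianceScaling(2.0)))\n",
        "_x = tf.random.uniform([8, 784])\n",
        "_y = tf.random.uniform([8], maxval=10, dtype=tf.int64)\n",
        "_loss, _stats = _demo.loss(_x, _y)\n",
        "print({k: v.numpy() for k, v in _stats.items()})"
      ]
    },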
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J8M6PJ4oNK0X"
      },
      "source": [
        "### Training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9TwBkiSmKWCz"
      },
      "outputs": [],
      "source": [
        "#@title Training configuration\n",
        "\n",
        "model_seed = 0 #@param\n",
        "\n",
        "alphas = [1.0, 2.0, 3.0, 4.0, 5.0] #@param\n",
        "init_distribution = 'truncated_normal'\n",
        "init_mode = 'fan_in'\n",
        "init_scale = 1.0\n",
        "\n",
        "# Fixed values taken from the Lottery Ticket Hypothesis paper\n",
        "train_batch_size = 60\n",
        "num_train_steps = 50000\n",
        "learning_rate = 0.1\n",
        "\n",
        "report_interval = 2500\n",
        "\n",
        "tf.random.set_seed(model_seed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "JGFVxUjCRe3v"
      },
      "outputs": [],
      "source": [
        "#@title Get data\n",
        "\n",
        "(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()\n",
        "\n",
        "train_x = train_x.reshape([60000, 784]).astype(np.float32) / 255.0\n",
        "test_x = test_x.reshape([10000, 784]).astype(np.float32) / 255.0\n",
        "\n",
        "train_y = train_y.astype(np.int64)\n",
        "test_y = test_y.astype(np.int64)\n",
        "\n",
        "# Reserve some data for a validation set\n",
        "valid_x = train_x[-5000:]\n",
        "valid_y = train_y[-5000:]\n",
        "train_x = train_x[:-5000]\n",
        "train_y = train_y[:-5000]\n",
        "\n",
        "train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))\n",
        "\n",
        "# Sample random batches of data from the entire training set\n",
        "train_iterator = iter(\n",
        "    train_dataset.repeat().shuffle(10000).batch(train_batch_size))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nWemRI-x4FEa"
      },
      "outputs": [],
      "source": [
        "#@title Training set-up\n",
        "\n",
        "model_types = []\n",
        "models = []\n",
        "n_models = len(alphas)\n",
        "\n",
        "for alpha in alphas:\n",
        "  w_init = PowerPropVarianceScaling(\n",
        "      alpha, scale=init_scale, mode=init_mode,\n",
        "      distribution=init_distribution)\n",
        "  models.append(DensityNetwork(MLP(alpha=alpha, w_init=w_init)))\n",
        "  if alpha \u003e 1.0:\n",
        "    model_types.append('Powerprop. ($\\\\alpha={}$)'.format(alpha))\n",
        "  else:\n",
        "    model_types.append('Baseline')\n",
        "\n",
        "# Initialise variables\n",
        "for m in models:\n",
        "  m(valid_x)\n",
        "\n",
        "initial_weights = [m.get_weights() for m in models]\n",
        "\n",
        "\n",
        "optimizers = [snt.optimizers.SGD(learning_rate=learning_rate)\n",
        "              for _ in range(n_models)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "xvtAESbqFGaQ"
      },
      "outputs": [],
      "source": [
        "#@title Training loop\n",
        "\n",
        "all_train_stats = [[] for _ in range(n_models)]\n",
        "all_valid_stats = [[] for _ in range(n_models)]\n",
        "\n",
        "for step in range(num_train_steps+1):\n",
        "  train_x_batch, train_y_batch = next(train_iterator)\n",
        "\n",
        "  for m_id, model in enumerate(models):\n",
        "    all_train_stats[m_id].append(\n",
        "        train_fn(model, train_x_batch, train_y_batch, optimizers[m_id]))\n",
        "\n",
        "  if step % report_interval == 0:\n",
        "    for m_id, model in enumerate(models):\n",
        "      print('[Train Step {}, Alpha {}] Loss: {:1.3f}. Acc: {:1.3f}'.format(\n",
        "          step,\n",
        "          alphas[m_id],\n",
        "          all_train_stats[m_id][-1]['loss'],\n",
        "          all_train_stats[m_id][-1]['acc']))\n",
        "      all_valid_stats[m_id].append(eval_fn(model, valid_x, valid_y))\n",
        "\n",
        "      print('[Eval Step {}, Alpha {}] Loss: {:1.3f}. Acc: {:1.3f}'.format(\n",
        "          step,\n",
        "          alphas[m_id],\n",
        "          all_valid_stats[m_id][-1]['loss'],\n",
        "          all_valid_stats[m_id][-1]['acc']))\n",
        "    print('---')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NDBNpC_4piOg"
      },
      "source": [
        "### Results"
      ]
    },
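    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The evaluation below performs one-shot magnitude pruning without any retraining: for each keep-fraction $p$ (20 points, geometrically spaced between 1% and 100%), the two hidden layers keep their $p$ largest-magnitude effective weights, the smaller output layer keeps $\\min(1, 2p)$, and test accuracy is measured with the resulting masks applied."
      ]
    },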
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "dG7tyvmOQxeX"
      },
      "outputs": [],
      "source": [
        "#@title Pruning\n",
        "\n",
        "final_weights = [m.get_weights() for m in models]\n",
        "\n",
        "eval_at_sparsity_level = np.geomspace(0.01, 1.0, 20).tolist()\n",
        "acc_at_sparsity = [[] for _ in range(n_models)]\n",
        "\n",
        "for p_to_use in eval_at_sparsity_level:\n",
        "\n",
        "  # Keep twice as many weights (i.e. half the sparsity) in the output layer\n",
        "  percent = 2*[p_to_use] + [min(1.0, p_to_use*2)]\n",
        "\n",
        "  for m_id, model_to_use in enumerate(models):\n",
        "    masks = []\n",
        "    for i, w in enumerate(final_weights[m_id]):\n",
        "      masks.append(prune_by_magnitude(percent[i], w))\n",
        "\n",
        "    _, stats = model_to_use.loss(test_x, test_y, masks=masks)\n",
        "\n",
        "    acc_at_sparsity[m_id].append(stats['acc'].numpy())\n",
        "    print(' Performance @ {:1.0f}% of weights [Alpha {}]: Acc {:1.3f} NLL {:1.3f} '.format(\n",
        "        100*p_to_use, alphas[m_id], stats['acc'], stats['loss']))\n",
        "  print('---')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "nv8u662-4yug"
      },
      "outputs": [],
      "source": [
        "#@title Plot\n",
        "\n",
        "sns.set_style(\"whitegrid\")\n",
        "sns.set_context(\"paper\")\n",
        "\n",
        "f, ax = plt.subplots(1, 1, figsize=(7, 5))\n",
        "\n",
        "for acc, label in zip(acc_at_sparsity, model_types):\n",
        "  ax.plot(eval_at_sparsity_level, acc, label=label, marker='o', lw=2)\n",
        "\n",
        "# The x-axis runs from dense (all weights) down to 1% remaining.\n",
        "ax.set_xscale('log')\n",
        "ax.set_xlim([1.0, 0.01])\n",
        "ax.set_ylim([0.0, 1.0])\n",
        "ax.legend(frameon=False)\n",
        "ax.set_xlabel('Fraction of Weights Remaining')\n",
        "ax.set_ylabel('Test Accuracy')\n",
        "\n",
        "sns.despine()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-NEnDx52w9NL"
      },
      "outputs": [],
      "source": [
        ""
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "zNZE9eU_poe0",
        "J8M6PJ4oNK0X",
        "NDBNpC_4piOg"
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "Powerpropagation",
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "12kDm1gRUsFQaQUfmf1AzONNtg2Oxm96h",
          "timestamp": 1621155153435
        },
        {
          "file_id": "1m7j2OKVrTR0YkvMXwDIZRmZdvFlB_HMk",
          "timestamp": 1615900826226
        },
        {
          "file_id": "11JbbynSJgkt4n92oEKrrgfHzCQo-4sRp",
          "timestamp": 1613740982607
        },
        {
          "file_id": "1Pa7Szt1OuHZXczMcMxLE_UctRRtMIbkO",
          "timestamp": 1612795019221
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}