From 7e6fd889e49bba2abc203c5faf70837d6126e42a Mon Sep 17 00:00:00 2001 From: Eren Sezener Date: Tue, 2 Mar 2021 21:45:42 +0000 Subject: [PATCH] Add a Colab implementing a dendritic gated network in numpy. PiperOrigin-RevId: 360505886 --- .../colabs/dendritic_gated_network.ipynb | 333 ++++++++++++++++++ 1 file changed, 333 insertions(+) create mode 100644 gated_linear_networks/colabs/dendritic_gated_network.ipynb diff --git a/gated_linear_networks/colabs/dendritic_gated_network.ipynb b/gated_linear_networks/colabs/dendritic_gated_network.ipynb new file mode 100644 index 0000000..2de2f09 --- /dev/null +++ b/gated_linear_networks/colabs/dendritic_gated_network.ipynb @@ -0,0 +1,333 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "s54x-Gq4AiYb" + }, + "source": [ + "## Simple Dendritic Gated Networks in numpy\n", + "\n", + "This colab implements a Dendritic Gated Network (DGN) solving a regression (using square loss) or a binary classification problem (using Bernoulli log loss). \n", + "\n", + "See our paper titled \"A rapid and efficient learning rule for biological neural circuits\" for details of the DGN model.\n", + "\n", + "\n", + "Some implementation details:\n", + "- We utilize `sklearn.datasets.load_breast_cancer` for binary classification and `sklearn.datasets.load_diabetes` for regression.\n", + "- This code is meant for educational purposes only. It is not optimized for high-performance, both in terms of computational efficiency and quality of fit. \n", + "- Network is trained on 80% of the dataset and tested on the rest. Test MSE or log loss is reported at the end of each epoch.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "jhiajfn0EAxE" + }, + "outputs": [], + "source": [ + "# Copyright 2021 DeepMind Technologies Limited. All rights reserved.\n", + "#\n", + "#\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "nm-F_uZA0_T2" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from sklearn import datasets\n", + "from sklearn import preprocessing\n", + "from sklearn import model_selection\n", + "from typing import List, Optional" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LOoiBATk1AgQ" + }, + "source": [ + "## Choose classification or regression" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "FCjzwzwh0ycl" + }, + "outputs": [], + "source": [ + "do_classification = True # if False, does regression" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TA5VmSeV-GTc" + }, + "source": [ + "### Load dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qnzNZrzNk3Pl" + }, + "outputs": [], + "source": [ + "if do_classification:\n", + " features, targets = datasets.load_breast_cancer(return_X_y=True)\n", + "\n", + "else:\n", + " features, targets = datasets.load_diabetes(return_X_y=True)\n", + "\n", + "\n", + "x_train, x_test, y_train, y_test = model_selection.train_test_split(\n", + " features, targets, test_size=0.2, random_state=0)\n", + "input_dim = x_train.shape[-1]\n", + "\n", + "feature_encoder = preprocessing.StandardScaler()\n", + "x_train = feature_encoder.fit_transform(x_train)\n", + "x_test = feature_encoder.transform(x_test)\n", + "\n", + "if not do_classification:\n", + " target_encoder = preprocessing.StandardScaler()\n", + " y_train = np.squeeze(target_encoder.fit_transform(y_train[:, np.newaxis]))\n", + " y_test = np.squeeze(target_encoder.transform(y_test[:, np.newaxis]))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LTQxvDcok86S" + }, + "source": [ + "## DGN inference/update" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "F6Yt_tw0lmf1" + }, + "outputs": [], + "source": [ + "def step_square_loss(inputs: np.ndarray,\n", + " weights: List[np.ndarray],\n", + " hyperplanes: List[np.ndarray],\n", + " hyperplane_bias_magnitude: Optional[float] = 1.,\n", + " learning_rate: Optional[float] = 1e-5,\n", + " target: Optional[float] = None,\n", + " update: bool = False,\n", + " ):\n", + " \"\"\"Implements a DGN inference/update using square loss.\"\"\"\n", + " r_in = inputs\n", + " side_info = np.hstack([hyperplane_bias_magnitude, inputs])\n", + "\n", + " for w, h in zip(weights, hyperplanes): # loop over layers\n", + " r_in = np.hstack([1., r_in]) # add biases\n", + " gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n", + " effective_weights = gate_values.dot(w).sum(axis=1)\n", + " r_out = effective_weights.dot(r_in)\n", + "\n", + " if update:\n", + " grad = (r_out[:, None] - target) * r_in[None]\n", + " w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n", + "\n", + " r_in = r_out\n", + " r_out = r_out[0]\n", + " loss = (target - r_out)**2 / 2\n", + " return r_out, loss\n", + "\n", + "def sigmoid(x): # numerically stable sigmoid\n", + " return np.exp(-np.logaddexp(0, -x))\n", + "\n", + "def inverse_sigmoid(x):\n", + " return np.log(x/(1-x))\n", + "\n", + "def step_bernoulli(inputs: np.ndarray,\n", + " weights: List[np.ndarray],\n", + " hyperplanes: List[np.ndarray],\n", + " hyperplane_bias_magnitude: Optional[float] = 1.,\n", + " learning_rate: Optional[float] = 1e-5,\n", + " epsilon: float = 0.01,\n", + " target: Optional[float] = None,\n", + " update: bool = False,\n", + " ):\n", + " \"\"\"Implements a DGN inference/update using Bernoulli log loss.\"\"\"\n", + " r_in = np.clip(sigmoid(inputs), epsilon, 1-epsilon)\n", + " side_info = np.hstack([hyperplane_bias_magnitude, inputs])\n", + "\n", + " for w, h in zip(weights, hyperplanes): # loop over layers\n", + " r_in = np.hstack([sigmoid(1.), r_in]) # add biases\n", + " h_in = inverse_sigmoid(r_in)\n", + " gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n", + " effective_weights = gate_values.dot(w).sum(axis=1)\n", + " h_out = effective_weights.dot(h_in)\n", + " r_out = np.clip(sigmoid(h_out), epsilon, 1 - epsilon)\n", + "\n", + " if update:\n", + " update_indicator = np.logical_and(r_out \u003c 1 - epsilon, r_out \u003e epsilon)\n", + " grad = (r_out[:, None] - target) * h_in[None] * update_indicator[:, None]\n", + " w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n", + "\n", + " r_in = r_out\n", + "\n", + " r_out = r_out[0]\n", + " loss = -(target * r_out + (1 - target) * (1 - r_out))\n", + " return r_out, loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0B7wSn3Azcfb" + }, + "outputs": [], + "source": [ + "def forward_pass(step_fn, x, y, weights, hyperplanes, learning_rate, update):\n", + " losses, outputs = [], []\n", + " for x_i, y_i in zip(x, y):\n", + " y, l = step_fn(x_i, weights, hyperplanes, target=y_i,\n", + " learning_rate=learning_rate, update=update)\n", + " losses.append(l)\n", + " outputs.append(y)\n", + " return np.mean(losses), np.array(outputs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "41aHT8G0lsuu" + }, + "source": [ + "## Define architecture\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WSbPuwzFvV2N" + }, + "outputs": [], + "source": [ + "# number of neurons per layer, the last element must be 1\n", + "num_neurons = np.array([100, 10, 1])\n", + "num_branches = 20 # number of dendritic brancher per neuron" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gTk1YDXV-xoD" + }, + "source": [ + "## Initialise weights and gating parameters" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Uek-2I5IlyN3" + }, + "outputs": [], + "source": [ + "num_inputs = np.hstack([input_dim + 1, num_neurons[:-1] + 1]) # 1 for the bias\n", + "weights_ = [np.zeros((num_neuron, num_branches, num_input))\n", + " for num_neuron, num_input in zip(num_neurons, num_inputs)]\n", + "hyperplanes_ = [np.random.normal(0, 1, size=(num_neuron, num_branches, input_dim + 1))\n", + " for num_neuron in num_neurons]\n", + "# By default, the weight parameters are drawn from a normalised Gaussian:\n", + "hyperplanes_ = [h_ / np.linalg.norm(h_[:, :, :-1], axis=(1, 2))[:, None, None]\n", + " for h_ in hyperplanes_]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Dy1XUdaSm0ID" + }, + "source": [ + "## Train" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wublBSqiucQ-" + }, + "outputs": [], + "source": [ + "if do_classification:\n", + " n_epochs = 3\n", + " learning_rate_const = 1e-4\n", + " step = step_bernoulli\n", + "else:\n", + " n_epochs = 10\n", + " learning_rate_const = 1e-5\n", + " step = step_square_loss\n", + "\n", + "for epoch in range(0, n_epochs):\n", + " train_loss, train_pred = forward_pass(\n", + " step, x_train, y_train, weights_,\n", + " hyperplanes_, learning_rate_const, update=True)\n", + " test_loss, test_pred = forward_pass(\n", + " step, x_test, y_test, weights_, hyperplanes_, learning_rate_const, update=False)\n", + " print('epoch: {:d}, test loss: {:.3f} (train_loss: {:.3f})'.format(\n", + " epoch, np.mean(test_loss), np.mean(train_loss)))\n", + "\n", + " if do_classification:\n", + " accuracy = 1 - np.mean(np.logical_xor(np.round(test_pred), y_test))\n", + " print('test accuracy: {:.3f}'.format(accuracy))\n" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/deepmind/dm_python:dm_notebook3", + "kind": "private" + }, + "name": "tp_dendritic_gated_network.ipynb", + "provenance": [ + { + "file_id": "1lzQUssVJpeziFs1fdBHueD7DqNp6lkVK", + "timestamp": 1614705435731 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}