Add a Colab implementing a dendritic gated network in numpy.

PiperOrigin-RevId: 360505886
This commit is contained in:
Eren Sezener
2021-03-02 21:45:42 +00:00
committed by Louise Deason
parent d1ae6a456f
commit 7e6fd889e4
@@ -0,0 +1,333 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "s54x-Gq4AiYb"
},
"source": [
"## Simple Dendritic Gated Networks in numpy\n",
"\n",
"This colab implements a Dendritic Gated Network (DGN) solving a regression (using square loss) or a binary classification problem (using Bernoulli log loss). \n",
"\n",
"See our paper titled \"A rapid and efficient learning rule for biological neural circuits\" for details of the DGN model.\n",
"\n",
"\n",
"Some implementation details:\n",
"- We utilize `sklearn.datasets.load_breast_cancer` for binary classification and `sklearn.datasets.load_diabetes` for regression.\n",
"- This code is meant for educational purposes only. It is not optimized for high-performance, both in terms of computational efficiency and quality of fit. \n",
"- Network is trained on 80% of the dataset and tested on the rest. Test MSE or log loss is reported at the end of each epoch.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jhiajfn0EAxE"
},
"outputs": [],
"source": [
"# Copyright 2021 DeepMind Technologies Limited. All rights reserved.\n",
"#\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nm-F_uZA0_T2"
},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn import datasets\n",
"from sklearn import preprocessing\n",
"from sklearn import model_selection\n",
"from typing import List, Optional"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LOoiBATk1AgQ"
},
"source": [
"## Choose classification or regression"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FCjzwzwh0ycl"
},
"outputs": [],
"source": [
"do_classification = True # if False, does regression"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TA5VmSeV-GTc"
},
"source": [
"### Load dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qnzNZrzNk3Pl"
},
"outputs": [],
"source": [
"if do_classification:\n",
" features, targets = datasets.load_breast_cancer(return_X_y=True)\n",
"\n",
"else:\n",
" features, targets = datasets.load_diabetes(return_X_y=True)\n",
"\n",
"\n",
"x_train, x_test, y_train, y_test = model_selection.train_test_split(\n",
" features, targets, test_size=0.2, random_state=0)\n",
"input_dim = x_train.shape[-1]\n",
"\n",
"feature_encoder = preprocessing.StandardScaler()\n",
"x_train = feature_encoder.fit_transform(x_train)\n",
"x_test = feature_encoder.transform(x_test)\n",
"\n",
"if not do_classification:\n",
" target_encoder = preprocessing.StandardScaler()\n",
" y_train = np.squeeze(target_encoder.fit_transform(y_train[:, np.newaxis]))\n",
" y_test = np.squeeze(target_encoder.transform(y_test[:, np.newaxis]))\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LTQxvDcok86S"
},
"source": [
"## DGN inference/update"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "F6Yt_tw0lmf1"
},
"outputs": [],
"source": [
"def step_square_loss(inputs: np.ndarray,\n",
" weights: List[np.ndarray],\n",
" hyperplanes: List[np.ndarray],\n",
" hyperplane_bias_magnitude: Optional[float] = 1.,\n",
" learning_rate: Optional[float] = 1e-5,\n",
" target: Optional[float] = None,\n",
" update: bool = False,\n",
" ):\n",
" \"\"\"Implements a DGN inference/update using square loss.\"\"\"\n",
" r_in = inputs\n",
" side_info = np.hstack([hyperplane_bias_magnitude, inputs])\n",
"\n",
" for w, h in zip(weights, hyperplanes): # loop over layers\n",
" r_in = np.hstack([1., r_in]) # add biases\n",
" gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n",
" effective_weights = gate_values.dot(w).sum(axis=1)\n",
" r_out = effective_weights.dot(r_in)\n",
"\n",
" if update:\n",
" grad = (r_out[:, None] - target) * r_in[None]\n",
" w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n",
"\n",
" r_in = r_out\n",
" r_out = r_out[0]\n",
" loss = (target - r_out)**2 / 2\n",
" return r_out, loss\n",
"\n",
"def sigmoid(x): # numerically stable sigmoid\n",
" return np.exp(-np.logaddexp(0, -x))\n",
"\n",
"def inverse_sigmoid(x):\n",
" return np.log(x/(1-x))\n",
"\n",
"def step_bernoulli(inputs: np.ndarray,\n",
" weights: List[np.ndarray],\n",
" hyperplanes: List[np.ndarray],\n",
" hyperplane_bias_magnitude: Optional[float] = 1.,\n",
" learning_rate: Optional[float] = 1e-5,\n",
" epsilon: float = 0.01,\n",
" target: Optional[float] = None,\n",
" update: bool = False,\n",
" ):\n",
" \"\"\"Implements a DGN inference/update using Bernoulli log loss.\"\"\"\n",
" r_in = np.clip(sigmoid(inputs), epsilon, 1-epsilon)\n",
" side_info = np.hstack([hyperplane_bias_magnitude, inputs])\n",
"\n",
" for w, h in zip(weights, hyperplanes): # loop over layers\n",
" r_in = np.hstack([sigmoid(1.), r_in]) # add biases\n",
" h_in = inverse_sigmoid(r_in)\n",
" gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n",
" effective_weights = gate_values.dot(w).sum(axis=1)\n",
" h_out = effective_weights.dot(h_in)\n",
" r_out = np.clip(sigmoid(h_out), epsilon, 1 - epsilon)\n",
"\n",
" if update:\n",
" update_indicator = np.logical_and(r_out \u003c 1 - epsilon, r_out \u003e epsilon)\n",
" grad = (r_out[:, None] - target) * h_in[None] * update_indicator[:, None]\n",
" w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n",
"\n",
" r_in = r_out\n",
"\n",
" r_out = r_out[0]\n",
" loss = -(target * r_out + (1 - target) * (1 - r_out))\n",
" return r_out, loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0B7wSn3Azcfb"
},
"outputs": [],
"source": [
"def forward_pass(step_fn, x, y, weights, hyperplanes, learning_rate, update):\n",
" losses, outputs = [], []\n",
" for x_i, y_i in zip(x, y):\n",
" y, l = step_fn(x_i, weights, hyperplanes, target=y_i,\n",
" learning_rate=learning_rate, update=update)\n",
" losses.append(l)\n",
" outputs.append(y)\n",
" return np.mean(losses), np.array(outputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "41aHT8G0lsuu"
},
"source": [
"## Define architecture\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "WSbPuwzFvV2N"
},
"outputs": [],
"source": [
"# number of neurons per layer, the last element must be 1\n",
"num_neurons = np.array([100, 10, 1])\n",
"num_branches = 20 # number of dendritic brancher per neuron"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "gTk1YDXV-xoD"
},
"source": [
"## Initialise weights and gating parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Uek-2I5IlyN3"
},
"outputs": [],
"source": [
"num_inputs = np.hstack([input_dim + 1, num_neurons[:-1] + 1]) # 1 for the bias\n",
"weights_ = [np.zeros((num_neuron, num_branches, num_input))\n",
" for num_neuron, num_input in zip(num_neurons, num_inputs)]\n",
"hyperplanes_ = [np.random.normal(0, 1, size=(num_neuron, num_branches, input_dim + 1))\n",
" for num_neuron in num_neurons]\n",
"# By default, the weight parameters are drawn from a normalised Gaussian:\n",
"hyperplanes_ = [h_ / np.linalg.norm(h_[:, :, :-1], axis=(1, 2))[:, None, None]\n",
" for h_ in hyperplanes_]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Dy1XUdaSm0ID"
},
"source": [
"## Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wublBSqiucQ-"
},
"outputs": [],
"source": [
"if do_classification:\n",
" n_epochs = 3\n",
" learning_rate_const = 1e-4\n",
" step = step_bernoulli\n",
"else:\n",
" n_epochs = 10\n",
" learning_rate_const = 1e-5\n",
" step = step_square_loss\n",
"\n",
"for epoch in range(0, n_epochs):\n",
" train_loss, train_pred = forward_pass(\n",
" step, x_train, y_train, weights_,\n",
" hyperplanes_, learning_rate_const, update=True)\n",
" test_loss, test_pred = forward_pass(\n",
" step, x_test, y_test, weights_, hyperplanes_, learning_rate_const, update=False)\n",
" print('epoch: {:d}, test loss: {:.3f} (train_loss: {:.3f})'.format(\n",
" epoch, np.mean(test_loss), np.mean(train_loss)))\n",
"\n",
" if do_classification:\n",
" accuracy = 1 - np.mean(np.logical_xor(np.round(test_pred), y_test))\n",
" print('test accuracy: {:.3f}'.format(accuracy))\n"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
"kind": "private"
},
"name": "tp_dendritic_gated_network.ipynb",
"provenance": [
{
"file_id": "1lzQUssVJpeziFs1fdBHueD7DqNp6lkVK",
"timestamp": 1614705435731
}
]
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}