mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Add a Colab implementing a dendritic gated network in numpy.
PiperOrigin-RevId: 360505886
This commit is contained in:
committed by
Louise Deason
parent
d1ae6a456f
commit
7e6fd889e4
@@ -0,0 +1,333 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "s54x-Gq4AiYb"
|
||||
},
|
||||
"source": [
|
||||
"## Simple Dendritic Gated Networks in numpy\n",
|
||||
"\n",
|
||||
"This colab implements a Dendritic Gated Network (DGN) solving a regression (using square loss) or a binary classification problem (using Bernoulli log loss). \n",
|
||||
"\n",
|
||||
"See our paper titled \"A rapid and efficient learning rule for biological neural circuits\" for details of the DGN model.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Some implementation details:\n",
|
||||
"- We utilize `sklearn.datasets.load_breast_cancer` for binary classification and `sklearn.datasets.load_diabetes` for regression.\n",
|
||||
"- This code is meant for educational purposes only. It is not optimized for high-performance, both in terms of computational efficiency and quality of fit. \n",
|
||||
"- Network is trained on 80% of the dataset and tested on the rest. Test MSE or log loss is reported at the end of each epoch.\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "jhiajfn0EAxE"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Copyright 2021 DeepMind Technologies Limited. All rights reserved.\n",
|
||||
"#\n",
|
||||
"#\n",
|
||||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
||||
"# you may not use this file except in compliance with the License.\n",
|
||||
"# You may obtain a copy of the License at\n",
|
||||
"#\n",
|
||||
"# https://www.apache.org/licenses/LICENSE-2.0\n",
|
||||
"#\n",
|
||||
"# Unless required by applicable law or agreed to in writing, software\n",
|
||||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
||||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
||||
"# See the License for the specific language governing permissions and\n",
|
||||
"# limitations under the License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "nm-F_uZA0_T2"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"from sklearn import datasets\n",
|
||||
"from sklearn import preprocessing\n",
|
||||
"from sklearn import model_selection\n",
|
||||
"from typing import List, Optional"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "LOoiBATk1AgQ"
|
||||
},
|
||||
"source": [
|
||||
"## Choose classification or regression"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "FCjzwzwh0ycl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"do_classification = True # if False, does regression"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "TA5VmSeV-GTc"
|
||||
},
|
||||
"source": [
|
||||
"### Load dataset"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "qnzNZrzNk3Pl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if do_classification:\n",
|
||||
" features, targets = datasets.load_breast_cancer(return_X_y=True)\n",
|
||||
"\n",
|
||||
"else:\n",
|
||||
" features, targets = datasets.load_diabetes(return_X_y=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x_train, x_test, y_train, y_test = model_selection.train_test_split(\n",
|
||||
" features, targets, test_size=0.2, random_state=0)\n",
|
||||
"input_dim = x_train.shape[-1]\n",
|
||||
"\n",
|
||||
"feature_encoder = preprocessing.StandardScaler()\n",
|
||||
"x_train = feature_encoder.fit_transform(x_train)\n",
|
||||
"x_test = feature_encoder.transform(x_test)\n",
|
||||
"\n",
|
||||
"if not do_classification:\n",
|
||||
" target_encoder = preprocessing.StandardScaler()\n",
|
||||
" y_train = np.squeeze(target_encoder.fit_transform(y_train[:, np.newaxis]))\n",
|
||||
" y_test = np.squeeze(target_encoder.transform(y_test[:, np.newaxis]))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "LTQxvDcok86S"
|
||||
},
|
||||
"source": [
|
||||
"## DGN inference/update"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "F6Yt_tw0lmf1"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def step_square_loss(inputs: np.ndarray,\n",
|
||||
" weights: List[np.ndarray],\n",
|
||||
" hyperplanes: List[np.ndarray],\n",
|
||||
" hyperplane_bias_magnitude: Optional[float] = 1.,\n",
|
||||
" learning_rate: Optional[float] = 1e-5,\n",
|
||||
" target: Optional[float] = None,\n",
|
||||
" update: bool = False,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Implements a DGN inference/update using square loss.\"\"\"\n",
|
||||
" r_in = inputs\n",
|
||||
" side_info = np.hstack([hyperplane_bias_magnitude, inputs])\n",
|
||||
"\n",
|
||||
" for w, h in zip(weights, hyperplanes): # loop over layers\n",
|
||||
" r_in = np.hstack([1., r_in]) # add biases\n",
|
||||
" gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n",
|
||||
" effective_weights = gate_values.dot(w).sum(axis=1)\n",
|
||||
" r_out = effective_weights.dot(r_in)\n",
|
||||
"\n",
|
||||
" if update:\n",
|
||||
" grad = (r_out[:, None] - target) * r_in[None]\n",
|
||||
" w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n",
|
||||
"\n",
|
||||
" r_in = r_out\n",
|
||||
" r_out = r_out[0]\n",
|
||||
" loss = (target - r_out)**2 / 2\n",
|
||||
" return r_out, loss\n",
|
||||
"\n",
|
||||
"def sigmoid(x): # numerically stable sigmoid\n",
|
||||
" return np.exp(-np.logaddexp(0, -x))\n",
|
||||
"\n",
|
||||
"def inverse_sigmoid(x):\n",
|
||||
" return np.log(x/(1-x))\n",
|
||||
"\n",
|
||||
"def step_bernoulli(inputs: np.ndarray,\n",
|
||||
" weights: List[np.ndarray],\n",
|
||||
" hyperplanes: List[np.ndarray],\n",
|
||||
" hyperplane_bias_magnitude: Optional[float] = 1.,\n",
|
||||
" learning_rate: Optional[float] = 1e-5,\n",
|
||||
" epsilon: float = 0.01,\n",
|
||||
" target: Optional[float] = None,\n",
|
||||
" update: bool = False,\n",
|
||||
" ):\n",
|
||||
" \"\"\"Implements a DGN inference/update using Bernoulli log loss.\"\"\"\n",
|
||||
" r_in = np.clip(sigmoid(inputs), epsilon, 1-epsilon)\n",
|
||||
" side_info = np.hstack([hyperplane_bias_magnitude, inputs])\n",
|
||||
"\n",
|
||||
" for w, h in zip(weights, hyperplanes): # loop over layers\n",
|
||||
" r_in = np.hstack([sigmoid(1.), r_in]) # add biases\n",
|
||||
" h_in = inverse_sigmoid(r_in)\n",
|
||||
" gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n",
|
||||
" effective_weights = gate_values.dot(w).sum(axis=1)\n",
|
||||
" h_out = effective_weights.dot(h_in)\n",
|
||||
" r_out = np.clip(sigmoid(h_out), epsilon, 1 - epsilon)\n",
|
||||
"\n",
|
||||
" if update:\n",
|
||||
" update_indicator = np.logical_and(r_out \u003c 1 - epsilon, r_out \u003e epsilon)\n",
|
||||
" grad = (r_out[:, None] - target) * h_in[None] * update_indicator[:, None]\n",
|
||||
" w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n",
|
||||
"\n",
|
||||
" r_in = r_out\n",
|
||||
"\n",
|
||||
" r_out = r_out[0]\n",
|
||||
" loss = -(target * r_out + (1 - target) * (1 - r_out))\n",
|
||||
" return r_out, loss"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "0B7wSn3Azcfb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def forward_pass(step_fn, x, y, weights, hyperplanes, learning_rate, update):\n",
|
||||
" losses, outputs = [], []\n",
|
||||
" for x_i, y_i in zip(x, y):\n",
|
||||
" y, l = step_fn(x_i, weights, hyperplanes, target=y_i,\n",
|
||||
" learning_rate=learning_rate, update=update)\n",
|
||||
" losses.append(l)\n",
|
||||
" outputs.append(y)\n",
|
||||
" return np.mean(losses), np.array(outputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "41aHT8G0lsuu"
|
||||
},
|
||||
"source": [
|
||||
"## Define architecture\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "WSbPuwzFvV2N"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# number of neurons per layer, the last element must be 1\n",
|
||||
"num_neurons = np.array([100, 10, 1])\n",
|
||||
"num_branches = 20 # number of dendritic brancher per neuron"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "gTk1YDXV-xoD"
|
||||
},
|
||||
"source": [
|
||||
"## Initialise weights and gating parameters"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "Uek-2I5IlyN3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_inputs = np.hstack([input_dim + 1, num_neurons[:-1] + 1]) # 1 for the bias\n",
|
||||
"weights_ = [np.zeros((num_neuron, num_branches, num_input))\n",
|
||||
" for num_neuron, num_input in zip(num_neurons, num_inputs)]\n",
|
||||
"hyperplanes_ = [np.random.normal(0, 1, size=(num_neuron, num_branches, input_dim + 1))\n",
|
||||
" for num_neuron in num_neurons]\n",
|
||||
"# By default, the weight parameters are drawn from a normalised Gaussian:\n",
|
||||
"hyperplanes_ = [h_ / np.linalg.norm(h_[:, :, :-1], axis=(1, 2))[:, None, None]\n",
|
||||
" for h_ in hyperplanes_]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Dy1XUdaSm0ID"
|
||||
},
|
||||
"source": [
|
||||
"## Train"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "wublBSqiucQ-"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if do_classification:\n",
|
||||
" n_epochs = 3\n",
|
||||
" learning_rate_const = 1e-4\n",
|
||||
" step = step_bernoulli\n",
|
||||
"else:\n",
|
||||
" n_epochs = 10\n",
|
||||
" learning_rate_const = 1e-5\n",
|
||||
" step = step_square_loss\n",
|
||||
"\n",
|
||||
"for epoch in range(0, n_epochs):\n",
|
||||
" train_loss, train_pred = forward_pass(\n",
|
||||
" step, x_train, y_train, weights_,\n",
|
||||
" hyperplanes_, learning_rate_const, update=True)\n",
|
||||
" test_loss, test_pred = forward_pass(\n",
|
||||
" step, x_test, y_test, weights_, hyperplanes_, learning_rate_const, update=False)\n",
|
||||
" print('epoch: {:d}, test loss: {:.3f} (train_loss: {:.3f})'.format(\n",
|
||||
" epoch, np.mean(test_loss), np.mean(train_loss)))\n",
|
||||
"\n",
|
||||
" if do_classification:\n",
|
||||
" accuracy = 1 - np.mean(np.logical_xor(np.round(test_pred), y_test))\n",
|
||||
" print('test accuracy: {:.3f}'.format(accuracy))\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"collapsed_sections": [],
|
||||
"last_runtime": {
|
||||
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
||||
"kind": "private"
|
||||
},
|
||||
"name": "tp_dendritic_gated_network.ipynb",
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1lzQUssVJpeziFs1fdBHueD7DqNp6lkVK",
|
||||
"timestamp": 1614705435731
|
||||
}
|
||||
]
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
Reference in New Issue
Block a user