mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Update DGN demo according to the most recent pseudocode.
PiperOrigin-RevId: 361780732
This commit is contained in:
committed by
Louise Deason
parent
e5bc0cdc10
commit
a5d5592972
@@ -0,0 +1,55 @@
|
||||
# Colabs
|
||||
|
||||
## Dendritic Gated Networks
|
||||
|
||||
`dendritic_gated_network.ipynb` implements a Dendritic Gated Network (DGN) solving a regression (using quadratic loss) or a binary classification problem (using Bernoulli log loss).
|
||||
|
||||
See our paper titled "A rapid and efficient learning rule for biological neural circuits" for details of the DGN model.
|
||||
|
||||
|
||||
### Instructions for running the `dendritic_gated_network.ipynb` colab/notebook.
|
||||
|
||||
We suggest running the [dendritic_gated_network.ipynb](https://github.com/deepmind/deepmind-research/blob/master/gated_linear_networks/colabs/dendritic_gated_network.ipynb) notebook using Google Colaboratory (or Colab). All the dependencies are included by default in Colab cloud runtimes (last tested on the 8th of March, 2021). See https://research.google.com/colaboratory/faq.html for web browser requirements. The notebook runs for about a minute using the free tier runtimes.
|
||||
|
||||
The code also runs in JupyterLab/JupyterNotebook (tested on Version 1.02).
|
||||
|
||||
1. Visit https://colab.research.google.com/
|
||||
2. Sign in with your Google account.
|
||||
3. Click on "File" and select "Open notebook".
|
||||
|
||||
4. Then you can
|
||||
* either open the notebook directly from GitHub:
|
||||
* Click on the GitHub tab
|
||||
* Paste https://github.com/deepmind/deepmind-research/blob/master/gated_linear_networks/colabs/dendritic_gated_network.ipynb into the URL section and click the search button. If the notebook does not open automatically, then select the correct notebook from the list provided.
|
||||
* or upload the provided notebook manually:
|
||||
* Click on the Upload tab
|
||||
* Choose or drag dendritic_gated_network.ipynb
|
||||
5. Click Connect (top right corner) to connect to a run time
|
||||
6. Click on the "Runtime" tab and select "Run all" to run all the cells.
|
||||
|
||||
### Expected outputs
|
||||
We provide the expected outputs below.
|
||||
|
||||
Classification (do_classification = True):
|
||||
|
||||
```
|
||||
epoch: 0, test loss: 0.693 (train: 0.693), test accuracy: 0.412 (train: 0.363)
|
||||
epoch: 1, test loss: 0.099 (train: 0.196), test accuracy: 0.974 (train: 0.963)
|
||||
epoch: 2, test loss: 0.095 (train: 0.079), test accuracy: 0.974 (train: 0.978)
|
||||
epoch: 3, test loss: 0.099 (train: 0.070), test accuracy: 0.974 (train: 0.982)
|
||||
```
|
||||
Regression (do_classification = False):
|
||||
|
||||
```
|
||||
epoch: 0, test loss: 0.419 (train: 0.500)
|
||||
epoch: 1, test loss: 0.388 (train: 0.486)
|
||||
epoch: 2, test loss: 0.354 (train: 0.439)
|
||||
epoch: 3, test loss: 0.328 (train: 0.400)
|
||||
epoch: 4, test loss: 0.310 (train: 0.369)
|
||||
epoch: 5, test loss: 0.297 (train: 0.344)
|
||||
epoch: 6, test loss: 0.287 (train: 0.324)
|
||||
epoch: 7, test loss: 0.281 (train: 0.308)
|
||||
epoch: 8, test loss: 0.277 (train: 0.296)
|
||||
epoch: 9, test loss: 0.275 (train: 0.285)
|
||||
epoch: 10, test loss: 0.273 (train: 0.277)
|
||||
```
|
||||
@@ -8,16 +8,15 @@
|
||||
"source": [
|
||||
"## Simple Dendritic Gated Networks in numpy\n",
|
||||
"\n",
|
||||
"This colab implements a Dendritic Gated Network (DGN) solving a regression (using square loss) or a binary classification problem (using Bernoulli log loss). \n",
|
||||
"This colab implements a Dendritic Gated Network (DGN) solving a regression (using quadratic loss) or a binary classification problem (using Bernoulli log loss).\n",
|
||||
"\n",
|
||||
"See our paper titled \"A rapid and efficient learning rule for biological neural circuits\" for details of the DGN model.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Some implementation details:\n",
|
||||
"- We utilize `sklearn.datasets.load_breast_cancer` for binary classification and `sklearn.datasets.load_diabetes` for regression.\n",
|
||||
"- This code is meant for educational purposes only. It is not optimized for high-performance, both in terms of computational efficiency and quality of fit. \n",
|
||||
"- Network is trained on 80% of the dataset and tested on the rest. Test MSE or log loss is reported at the end of each epoch.\n",
|
||||
"\n"
|
||||
"- This code is meant for educational purposes only. It is not optimized for high-performance, both in terms of computational efficiency and quality of fit.\n",
|
||||
"- Network is trained on 80% of the dataset and tested on the rest. For classification, we report log loss (negative log likelihood) and accuracy (percentage of correctly identified labels). For regression, we report MSE expressed in units of target variance."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -98,23 +97,23 @@
|
||||
"source": [
|
||||
"if do_classification:\n",
|
||||
" features, targets = datasets.load_breast_cancer(return_X_y=True)\n",
|
||||
"\n",
|
||||
"else:\n",
|
||||
" features, targets = datasets.load_diabetes(return_X_y=True)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x_train, x_test, y_train, y_test = model_selection.train_test_split(\n",
|
||||
" features, targets, test_size=0.2, random_state=0)\n",
|
||||
"input_dim = x_train.shape[-1]\n",
|
||||
"n_features = x_train.shape[-1]\n",
|
||||
"\n",
|
||||
"# Input features are centered and scaled to unit variance:\n",
|
||||
"feature_encoder = preprocessing.StandardScaler()\n",
|
||||
"x_train = feature_encoder.fit_transform(x_train)\n",
|
||||
"x_test = feature_encoder.transform(x_test)\n",
|
||||
"\n",
|
||||
"if not do_classification:\n",
|
||||
" # Continuous targets are centered and scaled to unit variance:\n",
|
||||
" target_encoder = preprocessing.StandardScaler()\n",
|
||||
" y_train = np.squeeze(target_encoder.fit_transform(y_train[:, np.newaxis]))\n",
|
||||
" y_test = np.squeeze(target_encoder.transform(y_test[:, np.newaxis]))\n"
|
||||
" y_test = np.squeeze(target_encoder.transform(y_test[:, np.newaxis]))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -157,7 +156,6 @@
|
||||
" w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n",
|
||||
"\n",
|
||||
" r_in = r_out\n",
|
||||
" r_out = r_out[0]\n",
|
||||
" loss = (target - r_out)**2 / 2\n",
|
||||
" return r_out, loss\n",
|
||||
"\n",
|
||||
@@ -186,17 +184,14 @@
|
||||
" gate_values = np.heaviside(h.dot(side_info), 0).astype(bool)\n",
|
||||
" effective_weights = gate_values.dot(w).sum(axis=1)\n",
|
||||
" h_out = effective_weights.dot(h_in)\n",
|
||||
" r_out = np.clip(sigmoid(h_out), epsilon, 1 - epsilon)\n",
|
||||
"\n",
|
||||
" r_out_unclipped = sigmoid(h_out)\n",
|
||||
" r_out = np.clip(r_out_unclipped, epsilon, 1 - epsilon)\n",
|
||||
" if update:\n",
|
||||
" update_indicator = np.logical_and(r_out \u003c 1 - epsilon, r_out \u003e epsilon)\n",
|
||||
" update_indicator = np.abs(target - r_out_unclipped) \u003e epsilon\n",
|
||||
" grad = (r_out[:, None] - target) * h_in[None] * update_indicator[:, None]\n",
|
||||
" w -= learning_rate * gate_values[:, :, None] * grad[:, None]\n",
|
||||
"\n",
|
||||
" r_in = r_out\n",
|
||||
"\n",
|
||||
" r_out = r_out[0]\n",
|
||||
" loss = -(target * r_out + (1 - target) * (1 - r_out))\n",
|
||||
" loss = - (target * np.log(r_out) + (1 - target) * np.log(1 - r_out))\n",
|
||||
" return r_out, loss"
|
||||
]
|
||||
},
|
||||
@@ -209,13 +204,11 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def forward_pass(step_fn, x, y, weights, hyperplanes, learning_rate, update):\n",
|
||||
" losses, outputs = [], []\n",
|
||||
" for x_i, y_i in zip(x, y):\n",
|
||||
" y, l = step_fn(x_i, weights, hyperplanes, target=y_i,\n",
|
||||
" learning_rate=learning_rate, update=update)\n",
|
||||
" losses.append(l)\n",
|
||||
" outputs.append(y)\n",
|
||||
" return np.mean(losses), np.array(outputs)"
|
||||
" losses, outputs = np.zeros(len(y)), np.zeros(len(y))\n",
|
||||
" for i, (x_i, y_i) in enumerate(zip(x, y)):\n",
|
||||
" outputs[i], losses[i] = step_fn(x_i, weights, hyperplanes, target=y_i,\n",
|
||||
" learning_rate=learning_rate, update=update)\n",
|
||||
" return np.mean(losses), outputs"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -236,8 +229,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# number of neurons per layer, the last element must be 1\n",
|
||||
"num_neurons = np.array([100, 10, 1])\n",
|
||||
"num_branches = 20 # number of dendritic brancher per neuron"
|
||||
"n_neurons = np.array([100, 10, 1])\n",
|
||||
"n_branches = 20 # number of dendritic brancher per neuron"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -257,14 +250,19 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"num_inputs = np.hstack([input_dim + 1, num_neurons[:-1] + 1]) # 1 for the bias\n",
|
||||
"weights_ = [np.zeros((num_neuron, num_branches, num_input))\n",
|
||||
" for num_neuron, num_input in zip(num_neurons, num_inputs)]\n",
|
||||
"hyperplanes_ = [np.random.normal(0, 1, size=(num_neuron, num_branches, input_dim + 1))\n",
|
||||
" for num_neuron in num_neurons]\n",
|
||||
"n_inputs = np.hstack([n_features + 1, n_neurons[:-1] + 1]) # 1 for the bias\n",
|
||||
"dgn_weights = [np.zeros((n_neuron, n_branches, n_input))\n",
|
||||
" for n_neuron, n_input in zip(n_neurons, n_inputs)]\n",
|
||||
"\n",
|
||||
"# Fixing random seed for reproducibility:\n",
|
||||
"np.random.seed(12345)\n",
|
||||
"dgn_hyperplanes = [\n",
|
||||
" np.random.normal(0, 1, size=(n_neuron, n_branches, n_features + 1))\n",
|
||||
" for n_neuron in n_neurons]\n",
|
||||
"# By default, the weight parameters are drawn from a normalised Gaussian:\n",
|
||||
"hyperplanes_ = [h_ / np.linalg.norm(h_[:, :, :-1], axis=(1, 2))[:, None, None]\n",
|
||||
" for h_ in hyperplanes_]"
|
||||
"dgn_hyperplanes = [\n",
|
||||
" h_ / np.linalg.norm(h_[:, :, :-1], axis=(1, 2))[:, None, None]\n",
|
||||
" for h_ in dgn_hyperplanes]"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -285,26 +283,40 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if do_classification:\n",
|
||||
" eta = 1e-4\n",
|
||||
" n_epochs = 3\n",
|
||||
" learning_rate_const = 1e-4\n",
|
||||
" step = step_bernoulli\n",
|
||||
"else:\n",
|
||||
" eta = 1e-5\n",
|
||||
" n_epochs = 10\n",
|
||||
" learning_rate_const = 1e-5\n",
|
||||
" step = step_square_loss\n",
|
||||
"\n",
|
||||
"for epoch in range(0, n_epochs):\n",
|
||||
"if do_classification:\n",
|
||||
" step = step_bernoulli\n",
|
||||
"else:\n",
|
||||
" step = step_square_loss\n",
|
||||
"\n",
|
||||
"print('Training on {} problem for {} epochs with learning rate {}.'.format(\n",
|
||||
" ['regression', 'classification'][do_classification], n_epochs, eta))\n",
|
||||
"print('This may take a minute. Please be patient...')\n",
|
||||
"\n",
|
||||
"for epoch in range(0, n_epochs + 1):\n",
|
||||
" train_loss, train_pred = forward_pass(\n",
|
||||
" step, x_train, y_train, weights_,\n",
|
||||
" hyperplanes_, learning_rate_const, update=True)\n",
|
||||
" step, x_train, y_train, dgn_weights,\n",
|
||||
" dgn_hyperplanes, eta, update=(epoch \u003e 0))\n",
|
||||
"\n",
|
||||
" test_loss, test_pred = forward_pass(\n",
|
||||
" step, x_test, y_test, weights_, hyperplanes_, learning_rate_const, update=False)\n",
|
||||
" print('epoch: {:d}, test loss: {:.3f} (train_loss: {:.3f})'.format(\n",
|
||||
" epoch, np.mean(test_loss), np.mean(train_loss)))\n",
|
||||
" step, x_test, y_test, dgn_weights,\n",
|
||||
" dgn_hyperplanes, eta, update=False)\n",
|
||||
" to_print = 'epoch: {}, test loss: {:.3f} (train: {:.3f})'.format(\n",
|
||||
" epoch, test_loss, train_loss)\n",
|
||||
"\n",
|
||||
" if do_classification:\n",
|
||||
" accuracy = 1 - np.mean(np.logical_xor(np.round(test_pred), y_test))\n",
|
||||
" print('test accuracy: {:.3f}'.format(accuracy))\n"
|
||||
" accuracy_train = np.mean(np.round(train_pred) == y_train)\n",
|
||||
" accuracy = np.mean(np.round(test_pred) == y_test)\n",
|
||||
" to_print += ', test accuracy: {:.3f} (train: {:.3f})'.format(\n",
|
||||
" accuracy, accuracy_train)\n",
|
||||
" print(to_print)"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -315,7 +327,8 @@
|
||||
"build_target": "//learning/deepmind/dm_python:dm_notebook3",
|
||||
"kind": "private"
|
||||
},
|
||||
"name": "tp_dendritic_gated_network.ipynb",
|
||||
"name": "dendritic_gated_network.ipynb",
|
||||
"private_outputs": true,
|
||||
"provenance": [
|
||||
{
|
||||
"file_id": "1lzQUssVJpeziFs1fdBHueD7DqNp6lkVK",
|
||||
|
||||
Reference in New Issue
Block a user