diff --git a/README.md b/README.md index 20d8df3..add2cd4 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ https://deepmind.com/research/publications/ ## Projects +* [Skilful precipitation nowcasting using deep generative models of radar](nowcasting), Nature 2021 * [Compute-Aided Design as Language](cadl) * [Encoders and ensembles for continual learning](continual_learning) * [Towards mental time travel: a hierarchical memory for reinforcement learning agents](hierarchical_transformer_memory) diff --git a/nowcasting/Open_sourced_dataset_and_model_snapshot_for_precipitation_nowcasting.ipynb b/nowcasting/Open_sourced_dataset_and_model_snapshot_for_precipitation_nowcasting.ipynb new file mode 100644 index 0000000..8e8968a --- /dev/null +++ b/nowcasting/Open_sourced_dataset_and_model_snapshot_for_precipitation_nowcasting.ipynb @@ -0,0 +1,1113 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "wFD0zFFyuHzH" + }, + "source": [ + "# Open-sourced dataset and model snapshot for precipitation nowcasting, accompanying the paper *Skillful Precipitation Nowcasting using Deep Generative Models of Radar, Ravuri et al. 2021.*\n", + "\n", + "This colab contains:\n", + "* Code to read the dataset using [Tensorflow 2](https://www.tensorflow.org/), with documentation of the available splits, variants and fields\n", + "* Example plots and animations of the data using [matplotlib](https://matplotlib.org/) and [cartopy](https://scitools.org.uk/cartopy/docs/latest/)\n", + "* A [TF-Hub](https://www.tensorflow.org/hub) snapshot of the model from the paper\n", + "* Example code to load this model and use it to make predictions.\n", + "\n", + "It has been tested in a public Google colab kernel." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d-H23Kuo7YM0" + }, + "source": [ + "## How to run this notebook\n", + "\n", + "All sections with the exception of 'Making predictions on a row from the full-frame test set (1536x1280)' can be evaluated on a free public Colab kernel. The final section requires more RAM than is available with a free kernel. To evaluate these cells you can either run your own local kernel (with \u003e= 24GB of RAM), or upgrade to Colab Pro.\n", + "\n", + "To launch a local colab kernel, please follow these [instructions](https://research.google.com/colaboratory/local-runtimes.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7V0C4sb4MEhx" + }, + "source": [ + "## License and attribution" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "N2SmJ8joMQ7G" + }, + "source": [ + "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n", + "\n", + "[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)\n", + "\n", + "Unless required by applicable law or agreed to in writing, software distributed\n", + "under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR\n", + "CONDITIONS OF ANY KIND, either express or implied. See the License for the\n", + "specific language governing permissions and limitations under the License.\n", + "\n", + "The datasets and the model snapshots associated with this colab are made available for use under the terms of the\n", + "[Creative Commons Attribution 4.0 International License](http://creativecommons.org/licenses/by/4.0/).\n", + "\n", + "This colab and the associated model snapshots are Copyright 2021 DeepMind Technologies Limited.\n", + "\n", + "The associated datasets contain public sector information licensed by the [Met Office](https://www.metoffice.gov.uk/) under the\n", + "[UK Open Government Licence v3.0](http://www.nationalarchives.gov.uk/doc/open-government-licence/version/3).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EE0_Q3CXv_wH" + }, + "source": [ + "## Library dependency installs and imports" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2AVIRsf0gcMd" + }, + "source": [ + "The following libraries are required. You can skip these `pip install` cells if your kernel already has them installed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vfV8LIhQgeFU" + }, + "outputs": [], + "source": [ + "!pip -q install tensorflow~=2.5.0 numpy~=1.19.5 matplotlib~=3.2.2 tensorflow_hub~=0.12.0 cartopy~=0.19.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qNvCVcQYitBK" + }, + "outputs": [], + "source": [ + "# Workaround for cartopy crashes due to the shapely installed by default in\n", + "# google colab kernel (https://github.com/anitagraser/movingpandas/issues/81):\n", + "!pip uninstall -y shapely\n", + "!pip install shapely --no-binary shapely" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rxRQBC3vjGNM" + }, + "source": [ + "## Imports:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "z_mTZ79PIw3j" + }, + "outputs": [], + "source": [ + "import datetime\n", + "import os\n", + "\n", + "import cartopy\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import animation\n", + "import numpy as np\n", + "import shapely.geometry as sgeom\n", + "import tensorflow as tf\n", + "import tensorflow_hub\n", + "\n", + "from google.colab import auth" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lcndQrFmjPuT" + }, + "source": [ + "## Dataset location" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "htgPqZQMJ3hZ" + }, + "outputs": [], + "source": [ + "# This Google Cloud Storage (GCS) bucket is free to access and contains an\n", + "# example subset of the full dataset (just the first shard of each\n", + "# split/variant):\n", + "EXAMPLE_DATASET_BUCKET_PATH = \"gs://dm-nowcasting-example-data/datasets/nowcasting_open_source_osgb/nimrod_osgb_1000m_yearly_splits/radar/20200718\"\n", + "\n", + "# This bucket is requester-pays and will require authentication. It contains the\n", + "# full dataset. We recommend downloading a local copy first and updating\n", + "# ROOT_DATASET_DIR below to the local path. This should save on transfer costs\n", + "# and speed up training.\n", + "FULL_DATASET_BUCKET_PATH = \"gs://dm-nowcasting/datasets/nowcasting_open_source_osgb/nimrod_osgb_1000m_yearly_splits/radar/20200718\"\n", + "\n", + "# Update this as required:\n", + "DATASET_ROOT_DIR = EXAMPLE_DATASET_BUCKET_PATH" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i3JWo990gxN8" + }, + "source": [ + "Use this to authenticate as required for access to GCS buckets:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-kxuWoPmg_fs" + }, + "outputs": [], + "source": [ + "auth.authenticate_user()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t4C0NDbfT0t9" + }, + "source": [ + "## Dataset reader code\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kKn1BcQuJsKd" + }, + "outputs": [], + "source": [ + "_FEATURES = {name: tf.io.FixedLenFeature([], dtype)\n", + " for name, dtype in [\n", + " (\"radar\", tf.string), (\"sample_prob\", tf.float32),\n", + " (\"osgb_extent_top\", tf.int64), (\"osgb_extent_left\", tf.int64),\n", + " (\"osgb_extent_bottom\", tf.int64), (\"osgb_extent_right\", tf.int64),\n", + " (\"end_time_timestamp\", tf.int64),\n", + " ]}\n", + "\n", + "_SHAPE_BY_SPLIT_VARIANT = {\n", + " (\"train\", \"random_crops_256\"): (24, 256, 256, 1),\n", + " (\"valid\", \"subsampled_tiles_256_20min_stride\"): (24, 256, 256, 1),\n", + " (\"test\", \"full_frame_20min_stride\"): (24, 1536, 1280, 1),\n", + " (\"test\", \"subsampled_overlapping_padded_tiles_512_20min_stride\"): (24, 512, 512, 1),\n", + "}\n", + "\n", + "_MM_PER_HOUR_INCREMENT = 1/32.\n", + "_MAX_MM_PER_HOUR = 128.\n", + "_INT16_MASK_VALUE = -1\n", + "\n", + "\n", + "def parse_and_preprocess_row(row, split, variant):\n", + " result = tf.io.parse_example(row, _FEATURES)\n", + " shape = _SHAPE_BY_SPLIT_VARIANT[(split, variant)]\n", + " radar_bytes = result.pop(\"radar\")\n", + " radar_int16 = tf.reshape(tf.io.decode_raw(radar_bytes, tf.int16), shape)\n", + " mask = tf.not_equal(radar_int16, _INT16_MASK_VALUE)\n", + " radar = tf.cast(radar_int16, tf.float32) * _MM_PER_HOUR_INCREMENT\n", + " radar = tf.clip_by_value(\n", + " radar, _INT16_MASK_VALUE * _MM_PER_HOUR_INCREMENT, _MAX_MM_PER_HOUR)\n", + " result[\"radar_frames\"] = radar\n", + " result[\"radar_mask\"] = mask\n", + " return result\n", + "\n", + "\n", + "def reader(split=\"train\", variant=\"random_crops_256\", shuffle_files=False):\n", + " \"\"\"Reader for open-source nowcasting datasets.\n", + " \n", + " Args:\n", + " split: Which yearly split of the dataset to use:\n", + " \"train\": Data from 2016 - 2018, excluding the first day of each month.\n", + " \"valid\": Data from 2016 - 2018, only the first day of the month.\n", + " \"test\": Data from 2019.\n", + " variant: Which variant to use. The available variants depend on the split:\n", + " \"random_crops_256\": Available for the training split. 24x256x256 pixel\n", + " crops, sampled with a bias towards crops containing rainfall. Crops at\n", + " all spatial and temporal offsets were able to be sampled, some crops may\n", + " overlap.\n", + " \"subsampled_tiles_256_20min_stride\": Available for the validation set.\n", + " Non-spatially-overlapping 24x256x256 pixel crops, subsampled from a\n", + " regular spatial grid with stride 256x256 pixels, and a temporal stride\n", + " of 20mins (4 timesteps at 5 minute resolution). Sampling favours crops\n", + " containing rainfall.\n", + " \"subsampled_overlapping_padded_tiles_512_20min_stride\": Available for the\n", + " test set. Overlapping 24x512x512 pixel crops, subsampled from a\n", + " regular spatial grid with stride 64x64 pixels, and a temporal stride\n", + " of 20mins (4 timesteps at 5 minute resolution). Subsampling favours\n", + " crops containing rainfall.\n", + " These crops include extra spatial context for a fairer evaluation of\n", + " the PySTEPS baseline, which benefits from this extra context. Our other\n", + " models only use the central 256x256 pixels of these crops.\n", + " \"full_frame_20min_stride\": Available for the test set. Includes full\n", + " frames at 24x1536x1280 pixels, every 20 minutes with no additional\n", + " subsampling.\n", + " shuffle_files: Whether to shuffle the shard files of the dataset\n", + " non-deterministically before interleaving them. Recommended for the\n", + " training set to improve mixing and read performance (since\n", + " non-deterministic parallel interleave is then enabled).\n", + "\n", + " Returns:\n", + " A tf.data.Dataset whose rows are dicts with the following keys:\n", + "\n", + " \"radar_frames\": Shape TxHxWx1, float32. Radar-based estimates of\n", + " ground-level precipitation, in units of mm/hr. Pixels which are masked\n", + " will take on a value of -1/32 and should be excluded from use as\n", + " evaluation targets. The coordinate reference system used is OSGB36, with\n", + " a spatial resolution of 1000 OSGB36 coordinate units (approximately equal\n", + " to 1km). The temporal resolution is 5 minutes.\n", + " \"radar_mask\": Shape TxHxWx1, bool. A binary mask which is False\n", + " for pixels that are unobserved / unable to be inferred from radar\n", + " measurements (e.g. due to being too far from a radar site). This mask\n", + " is usually static over time, but occasionally a whole radar site will\n", + " drop in or out resulting in large changes to the mask, and more localised\n", + " changes can happen too. \n", + " \"sample_prob\": Scalar float. The probability with which the row was\n", + " sampled from the overall pool available for sampling, as described above\n", + " under 'variants'. We use importance weights proportional to 1/sample_prob\n", + " when computing metrics on the validation and test set, to reduce bias due\n", + " to the subsampling.\n", + " \"end_time_timestamp\": Scalar int64. A timestamp for the final frame in\n", + " the example, in seconds since the UNIX epoch (1970-01-01 00:00:00 UTC).\n", + " \"osgb_extent_left\", \"osgb_extent_right\", \"osgb_extent_top\",\n", + " \"osgb_extent_bottom\":\n", + " Scalar int64s. Spatial extent for the crop in the OSGB36 coordinate\n", + " reference system.\n", + " \"\"\"\n", + " shards_glob = os.path.join(DATASET_ROOT_DIR, split, variant, \"*.tfrecord.gz\")\n", + " shard_paths = tf.io.gfile.glob(shards_glob)\n", + " shards_dataset = tf.data.Dataset.from_tensor_slices(shard_paths)\n", + " if shuffle_files:\n", + " shards_dataset = shards_dataset.shuffle(buffer_size=len(shard_paths))\n", + " return (\n", + " shards_dataset\n", + " .interleave(lambda x: tf.data.TFRecordDataset(x, compression_type=\"GZIP\"),\n", + " num_parallel_calls=tf.data.AUTOTUNE,\n", + " deterministic=not shuffle_files)\n", + " .map(lambda row: parse_and_preprocess_row(row, split, variant),\n", + " num_parallel_calls=tf.data.AUTOTUNE)\n", + " # Do your own subsequent repeat, shuffle, batch, prefetch etc as required.\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iiC7oGKRlNaj" + }, + "source": [ + "## Dataset reader documentation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "amkOoDKqlEll", + "outputId": "3141d2cd-fe88-452a-c9f8-c6b2eb449068" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Help on function reader in module __main__:\n", + "\n", + "reader(split='train', variant='random_crops_256', shuffle_files=False)\n", + " Reader for open-source nowcasting datasets.\n", + " \n", + " Args:\n", + " split: Which yearly split of the dataset to use:\n", + " \"train\": Data from 2016 - 2018, excluding the first day of each month.\n", + " \"valid\": Data from 2016 - 2018, only the first day of the month.\n", + " \"test\": Data from 2019.\n", + " variant: Which variant to use. The available variants depend on the split:\n", + " \"random_crops_256\": Available for the training split. 24x256x256 pixel\n", + " crops, sampled with a bias towards crops containing rainfall. Crops at\n", + " all spatial and temporal offsets were able to be sampled, some crops may\n", + " overlap.\n", + " \"subsampled_tiles_256_20min_stride\": Available for the validation set.\n", + " Non-spatially-overlapping 24x256x256 pixel crops, subsampled from a\n", + " regular spatial grid with stride 256x256 pixels, and a temporal stride\n", + " of 20mins (4 timesteps at 5 minute resolution). Sampling favours crops\n", + " containing rainfall.\n", + " \"subsampled_overlapping_padded_tiles_512_20min_stride\": Available for the\n", + " test set. Overlapping 24x512x512 pixel crops, subsampled from a\n", + " regular spatial grid with stride 64x64 pixels, and a temporal stride\n", + " of 20mins (4 timesteps at 5 minute resolution). Subsampling favours\n", + " crops containing rainfall.\n", + " These crops include extra spatial context for a fairer evaluation of\n", + " the PySTEPS baseline, which benefits from this extra context. Our other\n", + " models only use the central 256x256 pixels of these crops.\n", + " \"full_frame_20min_stride\": Available for the test set. Includes full\n", + " frames at 24x1536x1280 pixels, every 20 minutes with no additional\n", + " subsampling.\n", + " shuffle_files: Whether to shuffle the shard files of the dataset\n", + " non-deterministically before interleaving them. Recommended for the\n", + " training set to improve mixing and read performance (since\n", + " non-deterministic parallel interleave is then enabled).\n", + " \n", + " Returns:\n", + " A tf.data.Dataset whose rows are dicts with the following keys:\n", + " \n", + " \"radar_frames\": Shape TxHxWx1, float32. Radar-based estimates of\n", + " ground-level precipitation, in units of mm/hr. Pixels which are masked\n", + " will take on a value of -1/32 and should be excluded from use as\n", + " evaluation targets. The coordinate reference system used is OSGB36, with\n", + " a spatial resolution of 1000 OSGB36 coordinate units (approximately equal\n", + " to 1km). The temporal resolution is 5 minutes.\n", + " \"radar_mask\": Shape TxHxWx1, bool. A binary mask which is False\n", + " for pixels that are unobserved / unable to be inferred from radar\n", + " measurements (e.g. due to being too far from a radar site). This mask\n", + " is usually static over time, but occasionally a whole radar site will\n", + " drop in or out resulting in large changes to the mask, and more localised\n", + " changes can happen too. \n", + " \"sample_prob\": Scalar float. The probability with which the row was\n", + " sampled from the overall pool available for sampling, as described above\n", + " under 'variants'. We use importance weights proportional to 1/sample_prob\n", + " when computing metrics on the validation and test set, to reduce bias due\n", + " to the subsampling.\n", + " \"end_time_timestamp\": Scalar int64. A timestamp for the final frame in\n", + " the example, in seconds since the UNIX epoch (1970-01-01 00:00:00 UTC).\n", + " \"osgb_extent_left\", \"osgb_extent_right\", \"osgb_extent_top\",\n", + " \"osgb_extent_bottom\":\n", + " Scalar int64s. Spatial extent for the crop in the OSGB36 coordinate\n", + " reference system.\n", + "\n" + ] + } + ], + "source": [ + "help(reader)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QqZfqhvoTImW" + }, + "source": [ + "## Reading a row from the training set and inspecting types/shapes/values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "r3YPkBVWX6_q" + }, + "outputs": [], + "source": [ + "row = next(iter(reader(split=\"train\", variant=\"random_crops_256\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Z__EMLrRX_Oc" + }, + "outputs": [], + "source": [ + "{k: (v.dtype, v.shape) for k, v in row.items()}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "byuVVcSnXz4q" + }, + "source": [ + "Values for scalar features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "2i8X9KYXXtJi", + "outputId": "5b8a8980-f891-467d-994e-ccb28d9f1a16" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'end_time_timestamp': 1514725200,\n", + " 'osgb_extent_bottom': 555000,\n", + " 'osgb_extent_left': -9000,\n", + " 'osgb_extent_right': 247000,\n", + " 'osgb_extent_top': 811000,\n", + " 'sample_prob': 9.889281e-06}" + ] + }, + "execution_count": 10, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "{k: v.numpy() for k, v in row.items() if v.shape.ndims == 0}" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "10_lz7FYZ6sl" + }, + "source": [ + "Decoding the end_time_timestamp:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 35 + }, + "id": "WtAV24gyZ5jL", + "outputId": "cf8628ed-b6eb-4754-cce8-f1ff8bad14bf" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'2017-12-31T13:00:00'" + ] + }, + "execution_count": 11, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "datetime.datetime.utcfromtimestamp(row[\"end_time_timestamp\"]).isoformat()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e4HPIMX1VV5-" + }, + "source": [ + "## Visualization helpers" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ieLdf614RxTu" + }, + "outputs": [], + "source": [ + "matplotlib.rc('animation', html='jshtml')\n", + "\n", + "\n", + "def plot_animation(field, figsize=None,\n", + " vmin=0, vmax=10, cmap=\"jet\", **imshow_args):\n", + " fig = plt.figure(figsize=figsize)\n", + " ax = plt.axes()\n", + " ax.set_axis_off()\n", + " plt.close() # Prevents extra axes being plotted below animation\n", + " img = ax.imshow(field[0, ..., 0], vmin=vmin, vmax=vmax, cmap=cmap, **imshow_args)\n", + "\n", + " def animate(i):\n", + " img.set_data(field[i, ..., 0])\n", + " return (img,)\n", + "\n", + " return animation.FuncAnimation(\n", + " fig, animate, frames=field.shape[0], interval=24, blit=False)\n", + "\n", + "\n", + "class ExtendedOSGB(cartopy.crs.OSGB):\n", + " \"\"\"MET office radar data uses OSGB36 with an extended bounding box.\"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__(approx=False)\n", + "\n", + " @property\n", + " def x_limits(self):\n", + " return (-405000, 1320000)\n", + "\n", + " @property\n", + " def y_limits(self):\n", + " return (-625000, 1550000)\n", + "\n", + " @property\n", + " def boundary(self):\n", + " x0, x1 = self.x_limits\n", + " y0, y1 = self.y_limits\n", + " return sgeom.LinearRing([(x0, y0), (x0, y1), (x1, y1), (x1, y0), (x0, y0)])\n", + "\n", + "\n", + "def plot_rows_on_map(rows, field_name=\"radar_frames\", timestep=0, num_rows=None,\n", + " cbar_label=None, **imshow_kwargs):\n", + " fig = plt.figure(figsize=(10, 10))\n", + " axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())\n", + " if num_rows is None:\n", + " num_rows = next(iter(rows.values())).shape[0]\n", + " for b in range(num_rows):\n", + " extent = (rows[\"osgb_extent_left\"][b].numpy(),\n", + " rows[\"osgb_extent_right\"][b].numpy(),\n", + " rows[\"osgb_extent_bottom\"][b].numpy(),\n", + " rows[\"osgb_extent_top\"][b].numpy())\n", + " im = axes.imshow(rows[field_name][b, timestep, ..., 0].numpy(),\n", + " extent=extent, **imshow_kwargs)\n", + "\n", + " axes.set_xlim(*axes.projection.x_limits)\n", + " axes.set_ylim(*axes.projection.y_limits)\n", + " axes.set_facecolor(\"black\")\n", + " axes.gridlines(alpha=0.5)\n", + " axes.coastlines(resolution=\"50m\", color=\"white\")\n", + " if cbar_label:\n", + " cbar = fig.colorbar(im)\n", + " cbar.set_label(cbar_label)\n", + " return fig\n", + "\n", + "\n", + "def plot_animation_on_map(row):\n", + " fig = plt.figure(figsize=(10, 10))\n", + " axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())\n", + " plt.close() # Prevents extra axes being plotted below animation\n", + "\n", + " axes.gridlines(alpha=0.5)\n", + " axes.coastlines(resolution=\"50m\", color=\"white\")\n", + "\n", + " extent = (row[\"osgb_extent_left\"].numpy(),\n", + " row[\"osgb_extent_right\"].numpy(),\n", + " row[\"osgb_extent_bottom\"].numpy(),\n", + " row[\"osgb_extent_top\"].numpy())\n", + "\n", + " img = axes.imshow(\n", + " row[\"radar_frames\"][0, ..., 0].numpy(),\n", + " extent=extent, vmin=0, vmax=15., cmap=\"jet\")\n", + "\n", + " cbar = fig.colorbar(img)\n", + " cbar.set_label(\"Precipitation, mm/hr\")\n", + "\n", + " def animate(i):\n", + " return img.set_data(row[\"radar_frames\"][i, ..., 0].numpy()),\n", + "\n", + " return animation.FuncAnimation(\n", + " fig, animate, frames=row[\"radar_frames\"].shape[0],\n", + " interval=24, blit=False)\n", + "\n", + "\n", + "def plot_mask_on_map(row):\n", + " fig = plt.figure(figsize=(10, 10))\n", + " axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())\n", + " axes.gridlines(alpha=0.5)\n", + " axes.coastlines(resolution=\"50m\", color=\"black\")\n", + "\n", + " extent = (row[\"osgb_extent_left\"].numpy(),\n", + " row[\"osgb_extent_right\"].numpy(),\n", + " row[\"osgb_extent_bottom\"].numpy(),\n", + " row[\"osgb_extent_top\"].numpy())\n", + "\n", + " img = axes.imshow(\n", + " row[\"radar_mask\"][0, ..., 0].numpy(),\n", + " extent=extent, vmin=0, vmax=1, cmap=\"viridis\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZrfahN2wZS_G" + }, + "source": [ + "## Visualizing rows" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pn6C9qi4ZcYl" + }, + "source": [ + "Animation of a single row from the random_crops_256 training set (sequence of 24 frames at 256x256)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "J4_aB6vbXB0A" + }, + "outputs": [], + "source": [ + "plot_animation(row[\"radar_frames\"].numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fsDrJOCwZtZv" + }, + "source": [ + "And its mask. This may not always be interesting, sometimes it will be all ones. I only plot the first frame as this is usually static over time." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LtqAPI2utIDJ" + }, + "outputs": [], + "source": [ + "plt.imshow(row[\"radar_mask\"][0, ..., 0].numpy(), vmin=0, vmax=1);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jKalHlONmmTF" + }, + "source": [ + "Plotting an animation of a row from the full-frame test set" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gBZWTLFEbbal" + }, + "outputs": [], + "source": [ + "dataset = reader(split=\"test\", variant=\"full_frame_20min_stride\")\n", + "full_frame_test_set_row = next(iter(dataset))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qUTqwhC0gmIm" + }, + "outputs": [], + "source": [ + "plot_animation_on_map(full_frame_test_set_row)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Go5BVATqVwd" + }, + "source": [ + "And just its mask:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "T7cgRtRilKfv" + }, + "outputs": [], + "source": [ + "plot_mask_on_map(full_frame_test_set_row)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KtXg3lJxa1_q" + }, + "source": [ + "Plotting a few different crops from the training set on the same map, using their OSGB extents. Note these will have been sampled at different timestamps so won't be consistent with each other. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "x7HXOz7PfdtH" + }, + "outputs": [], + "source": [ + "BATCH_SIZE = 60\n", + "dataset = reader(split=\"train\", variant=\"random_crops_256\")\n", + "rows = next(iter(dataset.batch(BATCH_SIZE)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "76plhdrhaz60" + }, + "outputs": [], + "source": [ + "plot_rows_on_map(rows, field_name=\"radar_frames\", num_rows=10, vmin=0, vmax=15.,\n", + " cmap=\"jet\", cbar_label=\"Precipitation, mm/hr\");" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YwliM5nOaivq" + }, + "source": [ + "And plotting their masks, which will be more consistent with each other since they change less frequently." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rDUNaca0yXrC" + }, + "outputs": [], + "source": [ + "plot_rows_on_map(rows, field_name=\"radar_mask\", vmin=0, vmax=1, alpha=0.5, cmap=\"spring\");" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-KxxSOvwsUaP" + }, + "source": [ + "## Making predictions using model loaded from TF-Hub snapshots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IvWq8_4uvRBb" + }, + "source": [ + "Location of snapshots:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GZfhJ2orvWCV" + }, + "outputs": [], + "source": [ + "TFHUB_BASE_PATH = \"gs://dm-nowcasting-example-data/tfhub_snapshots\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aFLv08o5vX7y" + }, + "source": [ + "### Helper code for loading snapshots and making predictions with them" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zOFI9xQztNIW" + }, + "outputs": [], + "source": [ + "def load_module(input_height, input_width):\n", + " \"\"\"Load a TF-Hub snapshot of the 'Generative Method' model.\"\"\"\n", + " hub_module = tensorflow_hub.load(\n", + " os.path.join(TFHUB_BASE_PATH, f\"{input_height}x{input_width}\"))\n", + " # Note this has loaded a legacy TF1 model for running under TF2 eager mode.\n", + " # This means we need to access the module via the \"signatures\" attribute. See\n", + " # https://github.com/tensorflow/hub/blob/master/docs/migration_tf2.md#using-lower-level-apis\n", + " # for more information.\n", + " return hub_module.signatures['default']\n", + "\n", + "\n", + "def predict(module, input_frames, num_samples=1,\n", + " include_input_frames_in_result=False):\n", + " \"\"\"Make predictions from a TF-Hub snapshot of the 'Generative Method' model.\n", + "\n", + " Args:\n", + " module: One of the raw TF-Hub modules returned by load_module above.\n", + " input_frames: Shape (T_in,H,W,C), where T_in = 4. Input frames to condition\n", + " the predictions on.\n", + " num_samples: The number of different samples to draw.\n", + " include_input_frames_in_result: If True, will return a total of 22 frames\n", + " along the time axis, the 4 input frames followed by 18 predicted frames.\n", + " Otherwise will only return the 18 predicted frames.\n", + "\n", + " Returns:\n", + " A tensor of shape (num_samples,T_out,H,W,C), where T_out is either 18 or 22\n", + " as described above.\n", + " \"\"\"\n", + " input_frames = tf.math.maximum(input_frames, 0.)\n", + " # Add a batch dimension and tile along it to create a copy of the input for\n", + " # each sample:\n", + " input_frames = tf.expand_dims(input_frames, 0)\n", + " input_frames = tf.tile(input_frames, multiples=[num_samples, 1, 1, 1, 1])\n", + "\n", + " # Sample the latent vector z for each sample:\n", + " _, input_signature = module.structured_input_signature\n", + " z_size = input_signature['z'].shape[1]\n", + " z_samples = tf.random.normal(shape=(num_samples, z_size))\n", + "\n", + " inputs = {\n", + " \"z\": z_samples,\n", + " \"labels$onehot\" : tf.ones(shape=(num_samples, 1)),\n", + " \"labels$cond_frames\" : input_frames\n", + " }\n", + " samples = module(**inputs)['default']\n", + " if not include_input_frames_in_result:\n", + " # The module returns the input frames alongside its sampled predictions, we\n", + " # slice out just the predictions:\n", + " samples = samples[:, NUM_INPUT_FRAMES:, ...]\n", + "\n", + " # Take positive values of rainfall only.\n", + " samples = tf.math.maximum(samples, 0.)\n", + " return samples\n", + "\n", + "\n", + "# Fixed values supported by the snapshotted model.\n", + "NUM_INPUT_FRAMES = 4\n", + "NUM_TARGET_FRAMES = 18\n", + "\n", + "\n", + "def extract_input_and_target_frames(radar_frames):\n", + " \"\"\"Extract input and target frames from a dataset row's radar_frames.\"\"\"\n", + " # We align our targets to the end of the window, and inputs precede targets.\n", + " input_frames = radar_frames[-NUM_TARGET_FRAMES-NUM_INPUT_FRAMES : -NUM_TARGET_FRAMES]\n", + " target_frames = radar_frames[-NUM_TARGET_FRAMES : ]\n", + " return input_frames, target_frames\n", + "\n", + "\n", + "def horizontally_concatenate_batch(samples):\n", + " n, t, h, w, c = samples.shape\n", + " # N,T,H,W,C =\u003e T,H,N,W,C =\u003e T,H,N*W,C\n", + " return tf.reshape(tf.transpose(samples, [1, 2, 0, 3, 4]), [t, h, n*w, c])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YuTsqBqM1PeO" + }, + "source": [ + "### Making predictions for a row from the validation set (256x256 crops)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kX6wgACVtcHz" + }, + "outputs": [], + "source": [ + "module = load_module(256, 256)\n", + "row = next(iter(reader(split=\"valid\", variant=\"subsampled_tiles_256_20min_stride\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LKBlYh0qx7nL" + }, + "outputs": [], + "source": [ + "num_samples = 5\n", + "input_frames, target_frames = extract_input_and_target_frames(row[\"radar_frames\"])\n", + "samples = predict(module, input_frames,\n", + " num_samples=num_samples, include_input_frames_in_result=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NDOh7j2g65wP" + }, + "source": [ + "We will plot an animation of 5 different samples, including the input frames first (so all 5 will start the same). You can see they end up in different places." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_Ymu_GFBzGJd" + }, + "outputs": [], + "source": [ + "plot_animation(horizontally_concatenate_batch(samples), figsize=(4*num_samples, 4))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "57uh8X2XEqqj" + }, + "source": [ + "### Making predictions on a row from the full-frame test set (1536x1280)\n", + "\n", + "Warning: this will require more RAM than is available in a free public colab kernel, even if you reduce num_samples to 1." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RdnIsm725Nvk" + }, + "outputs": [], + "source": [ + "# This is the same model with same parameters as above; we have had to export\n", + "# separate copies of the graph for each input size as the input size is\n", + "# unfortunately hardcoded into the graph as static shapes.\n", + "module = load_module(1536, 1280)\n", + "\n", + "full_frame_test_set_row = next(iter(\n", + " reader(split=\"test\", variant=\"full_frame_20min_stride\")))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sLQdfet59caW" + }, + "outputs": [], + "source": [ + "num_samples = 2\n", + "input_frames, target_frames = extract_input_and_target_frames(\n", + " full_frame_test_set_row[\"radar_frames\"])\n", + "samples = predict(module, input_frames,\n", + " num_samples=num_samples, include_input_frames_in_result=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K_2UwKa_E3Xo" + }, + "source": [ + "Plotting two different predicted samples following on from the input frames. The first sample:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MytvlcaA8P4P" + }, + "outputs": [], + "source": [ + "row_with_predictions = full_frame_test_set_row.copy()\n", + "row_with_predictions[\"radar_frames\"] = samples[0]\n", + "plot_animation_on_map(row_with_predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AznUbE6-E53h" + }, + "source": [ + "And the second sample:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Er86IKdm-W5l" + }, + "outputs": [], + "source": [ + "row_with_predictions[\"radar_frames\"] = samples[1]\n", + "plot_animation_on_map(row_with_predictions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mnNy6Uurp1_k" + }, + "source": [ + "The ground truth, for comparison, was plotted earlier as an example row from the test set." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "b9v4xCWqqdzV" + }, + "outputs": [], + "source": [ + "" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "EE0_Q3CXv_wH", + "t4C0NDbfT0t9", + "e4HPIMX1VV5-" + ], + "name": "Open-sourced dataset and model snapshot for precipitation nowcasting", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/nowcasting/README.md b/nowcasting/README.md new file mode 100644 index 0000000..5c1152c --- /dev/null +++ b/nowcasting/README.md @@ -0,0 +1,37 @@ +# Skillful Precipitation Nowcasting Using Deep Generative Models of Radar + +This repository is a supplement to "Skillful Precipitation Nowcasting using Deep +Generative Models of Radar" and provides necessary code for loading data from a +large scale nowcasting dataset and obtaining predictions with the pretrained +model. + +Please see the Colab notebook for further details: + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepmind/deepmind-research/nowcasting/blob/master/Open_sourced_dataset_and_model_snapshot_for_precipitation_nowcasting.ipynb) + +## License + +The Colab notebook is licensed under the Apache License, Version 2.0. The +associated model snapshots are made available for use under the terms of the +[Creative Commons Attribution 4.0 International License][cc-by]. + +The provided post-processed nowcasting dataset is licensed under a +[Creative Commons Attribution 4.0 International License][cc-by] and it contains +public sector information licensed by the Met Office under the +[Open Government Licence v3.0][open-govt-license]. + +## Citation + +If you use this work, consider citing our paper: + +```latex +@article{ravuris2021skillful, + author={Suman Ravuri and Karel Lenc and Matthew Willson and Dmitry Kangin and Remi Lam and Piotr Mirowski and Megan Fitzsimons and Maria Athanassiadou and Sheleem Kashem and Sam Madge and Rachel Prudden Amol Mandhane and Aidan Clark and Andrew Brock and Karen Simonyan and Raia Hadsell and Niall Robinson Ellen Clancy and Alberto Arribas† and Shakir Mohamed}, + title={Skillful Precipitation Nowcasting using Deep Generative Models of Radar}, + journal={}, + year={2021} +} +``` + +[cc-by]: http://creativecommons.org/licenses/by/4.0/ +[open-govt-license]: http://www.nationalarchives.gov.uk/doc/open-government-licence/version/3/