Add OGB-LSC code.

PiperOrigin-RevId: 379591158
This commit is contained in:
Alvaro Sanchez-Gonzalez
2021-06-15 23:45:45 +01:00
committed by Saran Tunyasuvunakool
parent 88f674ad87
commit a3c133404e
40 changed files with 6934 additions and 1 deletions
+1 -1
View File
@@ -24,6 +24,7 @@ https://deepmind.com/research/publications/
## Projects
* [Open Graph Benchmark - Large-Scale Challenge (OGB-LSC)](ogb_lsc)
* [Synthetic Returns for Long-Term Credit Assignment](synthetic_returns)
* [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers)
* [Better, Faster Fermionic Neural Networks](kfac_ferminet_alpha) (KFAC implementation)
@@ -73,7 +74,6 @@ https://deepmind.com/research/publications/
## Disclaimer
*This is not an official Google product.*
+29
View File
@@ -0,0 +1,29 @@
# DeepMind entry for PCQM4M-LSC
This repository contains DeepMind's entry to the [PCWM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry) and
[MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph)
tracks of the [OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/)
(OGB-LSC).
For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
## Code structure
* `pcq/`: Scripts for training, evaluating on the [PCQ dataset](https://ogb.stanford.edu/docs/graphprop/).
* `mag/`: Scripts for training, evaluating on the [MAG dataset](https://ogb.stanford.edu/docs/nodeprop/).
## Citation
To cite this work:
```latex
@article{deepmind2021ogb,
author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea
Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro
Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar
Veli\v{c}kovi\'{c}},
title = {Large-scale graph representation learning with very deep GNNs and
self-supervision},
year = {2021},
}
```
+128
View File
@@ -0,0 +1,128 @@
# DeepMind entry for MAG240M-LSC
This repository contains DeepMind's entry to the [MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph) track of the
[OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/) (OGB-LSC).
For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
## DeepMind MAG Team ("Academic")
(in alphabetical order)
- Ravichandra Addanki
- Peter Battaglia
- David Budden
- Andreea Deac
- Jonathan Godwin
- Thomas Keck
- Alvaro Sanchez-Gonzalez
- Jacklynn Stott
- Shantanu Thakoor
- Petar Veličković
## Performance
Our final test set performance was achieved by pooling an ensemble of 10 folds.
See [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf) for details.
Each model was trained for < 72 hours using 4x Google Cloud TPUv4 and 1x AMD
EPYC 7B12 64-core CPU @2.25GHz.
Inference takes < 12 hours on 4x NVIDIA V100 16GB GPU and 1x Intel Xeon Gold
6148 20-core CPU @2.40GHz.
# Running our model
## Setup
You can set up Python virtual environment (you might need to install the
`python3-venv` package first) with all needed dependencies inside the forked
`deepmind_research` repository using:
```bash
python3 -m venv /tmp/mag_venv
source /tmp/mag_venv/bin/activate
pip3 install --upgrade pip setuptools wheel
pip3 install -r ogb_lsc/mag/requirements.txt
```
## Download and pre-process data
**1. Download the dataset using the contest toolkit ([here](https://ogb.stanford.edu/kddcup2021/mag240m/#dataset)) to a local directory
`ROOT`.**
**2. Run this script to reorganize the data into a flat directory structure with
transparent names**
```bash
/bin/bash organize_data.sh -r ROOT
```
Once this completes, a new directory `ROOT/mag240m_kddcup2021/raw` will be
created, with contents:
- `node_feat.npy`
- `node_label.npy`
- `node_year.npy`
- `author_affiliated_with_institution_edges.npy`
- `author_writes_paper_edges.npy`
- `paper_cites_paper_edges.npy`
- `train_idx.npy`
- `valid_idx.npy`
- `test_idx.npy`
We refer to this as the "raw" data.
**3. To run the preprocessing code**
```bash
/bin/bash run_preprocessing.sh -r ROOT
```
The pre-processing is very time- and memory-consuming, and should only be run
to verify the full pipeline. You can download the pre-processed data using the
following script, for use in training and evaluating models:
```bash
python3 download_mag.py --task_root=${HOME}/mag --payload="data"
```
## Reproducing our final results
We have provided pre-trained weights of our final submission for convenience.
They can be downloaded with:
```
python3 download_mag.py --task_root=${HOME}/mag --payload="models"
```
Then to reproduce our final results, please run `bash run_pretrain_eval.sh`.
## Retraining our model
Disclaimer: This script is provided for illustrative purposes. It is not
practical for actual training since it only uses a single machine, and likely
requires reducing the batch size and/or model size to fit on a single GPU.
If you still want to train a model, please run `run_training.sh`. To simply
validate that the code is running correctly on your hardware setup, consider
setting `debug=True` in `config.py`, which trains a smaller model.
# Citation
To cite this work (together with our PCQM4M-LSC entry):
```latex
@article{deepmind2021ogb,
author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea
Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro
Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar
Veli\v{c}kovi\'{c}},
title = {Large-scale graph representation learning with very deep GNNs and
self-supervision},
year = {2021},
}
```
Our technical report can be found [here](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
+138
View File
@@ -0,0 +1,138 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dynamic batching utilities."""
from typing import Generator, Iterable, Iterator, Sequence, Tuple
import jax.tree_util as tree
import jraph
import numpy as np
_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph")
def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple],
n_node: int, n_edge: int,
n_graph: int) -> Generator[jraph.GraphsTuple, None, None]:
"""Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`.
Elements of the `graphs_tuple_iterator` will be incrementally added to a batch
until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This
means each element yielded by this generator
For situations where you have variable sized data, it"s useful to be able to
have variable sized batches. This is especially the case if you have a loss
defined on the variable shaped element (for example, nodes in a graph).
Args:
graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`.
n_node: The maximum number of nodes in a batch.
n_edge: The maximum number of edges in a batch.
n_graph: The maximum number of graphs in a batch.
Yields:
A `jraph.GraphsTuple` batch of graphs.
Raises:
ValueError: if the number of graphs is < 2.
RuntimeError: if the `graphs_tuple_iterator` contains elements which are not
`jraph.GraphsTuple`s.
RuntimeError: if a graph is found which is larger than the batch size.
"""
if n_graph < 2:
raise ValueError("The number of graphs in a batch size must be greater or "
f"equal to `2` for padding with graphs, got {n_graph}.")
valid_batch_size = (n_node - 1, n_edge, n_graph - 1)
accumulated_graphs = []
num_accumulated_nodes = 0
num_accumulated_edges = 0
num_accumulated_graphs = 0
for element in graphs_tuple_iterator:
element_nodes, element_edges, element_graphs = _get_graph_size(element)
if _is_over_batch_size(element, valid_batch_size):
graph_size = element_nodes, element_edges, element_graphs
graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)}
batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)}
raise RuntimeError("Found graph bigger than batch size. Valid Batch "
f"Size: {batch_size}, Graph Size: {graph_size}")
if not accumulated_graphs:
# If this is the first element of the batch, set it and continue.
accumulated_graphs = [element]
num_accumulated_nodes = element_nodes
num_accumulated_edges = element_edges
num_accumulated_graphs = element_graphs
continue
else:
# Otherwise check if there is space for the graph in the batch:
if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
(num_accumulated_nodes + element_nodes > n_node - 1) or
(num_accumulated_edges + element_edges > n_edge)):
# If there is, add it to the batch
batched_graph = _batch_np(accumulated_graphs)
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
accumulated_graphs = [element]
num_accumulated_nodes = element_nodes
num_accumulated_edges = element_edges
num_accumulated_graphs = element_graphs
else:
# Otherwise, return the old batch and start a new batch.
accumulated_graphs.append(element)
num_accumulated_nodes += element_nodes
num_accumulated_edges += element_edges
num_accumulated_graphs += element_graphs
# We may still have data in batched graph.
if accumulated_graphs:
batched_graph = _batch_np(accumulated_graphs)
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
# Calculates offsets for sender and receiver arrays, caused by concatenating
# the nodes arrays.
offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]]))
def _map_concat(nests):
concat = lambda *args: np.concatenate(args)
return tree.tree_multimap(concat, *nests)
return jraph.GraphsTuple(
n_node=np.concatenate([g.n_node for g in graphs]),
n_edge=np.concatenate([g.n_edge for g in graphs]),
nodes=_map_concat([g.nodes for g in graphs]),
edges=_map_concat([g.edges for g in graphs]),
globals=_map_concat([g.globals for g in graphs]),
senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
receivers=np.concatenate(
[g.receivers + o for g, o in zip(graphs, offsets)]))
def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]:
n_node = np.sum(graph.n_node)
n_edge = len(graph.senders)
n_graph = len(graph.n_node)
return n_node, n_edge, n_graph
def _is_over_batch_size(
graph: jraph.GraphsTuple,
graph_batch_size: Iterable[int],
) -> bool:
graph_size = _get_graph_size(graph)
return any([x > y for x, y in zip(graph_size, graph_batch_size)])
+117
View File
@@ -0,0 +1,117 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experiment config for MAG240M-LSC entry."""
from jaxline import base_config
from ml_collections import config_dict
def get_config(debug: bool = False) -> config_dict.ConfigDict:
"""Get Jaxline experiment config."""
config = base_config.get_base_config()
config.random_seed = 42
# E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
config.restore_path = config_dict.placeholder(str)
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
debug=debug,
predictions_dir=config_dict.placeholder(str),
# 5 for model selection and early stopping, 50 for final eval.
num_eval_iterations_to_ensemble=5,
dataset_kwargs=dict(
data_root='/data/',
online_subsampling_kwargs=dict(
max_nb_neighbours_per_type=[
[[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]],
[[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]],
],
remove_future_nodes=True,
deduplicate_nodes=True,
),
ratio_unlabeled_data_to_labeled_data=10.0,
k_fold_split_id=config_dict.placeholder(int),
use_all_labels_when_not_training=False,
use_dummy_adjacencies=debug,
),
optimizer=dict(
name='adamw',
kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999),
learning_rate_schedule=dict(
use_schedule=True,
base_learning_rate=1e-2,
warmup_steps=50000,
total_steps=config.get_ref('training_steps'),
),
),
model_config=dict(
mlp_hidden_sizes=[32] if debug else [512],
latent_size=32 if debug else 256,
num_message_passing_steps=2 if debug else 4,
activation='relu',
dropout_rate=0.3,
dropedge_rate=0.25,
disable_edge_updates=True,
use_sent_edges=True,
normalization_type='layer_norm',
aggregation_function='sum',
),
training=dict(
loss_config=dict(
bgrl_loss_config=dict(
stop_gradient_for_supervised_loss=False,
bgrl_loss_scale=1.0,
symmetrize=True,
first_graph_corruption_config=dict(
feature_drop_prob=0.4,
edge_drop_prob=0.2,
),
second_graph_corruption_config=dict(
feature_drop_prob=0.4,
edge_drop_prob=0.2,
),
),
),
# GPU memory may require reducing the `256`s below to `48`.
dynamic_batch_size_config=dict(
n_node=256 if debug else 340 * 256,
n_edge=512 if debug else 720 * 256,
n_graph=4 if debug else 256,
),
),
eval=dict(
split='valid',
ema_annealing_schedule=dict(
use_schedule=True,
base_rate=0.999,
total_steps=config.get_ref('training_steps')),
dynamic_batch_size_config=dict(
n_node=256 if debug else 340 * 128,
n_edge=512 if debug else 720 * 128,
n_graph=4 if debug else 128,
),
))))
## Training loop config.
config.training_steps = 500000
config.checkpoint_dir = '/tmp/checkpoint/mag/'
config.train_checkpoint_all_hosts = False
config.log_train_data_interval = 10
config.log_tensors_interval = 10
config.save_checkpoint_interval = 30
config.best_model_eval_metric = 'accuracy'
config.best_model_eval_metric_higher_is_better = True
return config
+136
View File
@@ -0,0 +1,136 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Builds CSR matrices which store the MAG graphs."""
import pathlib
from absl import app
from absl import flags
from absl import logging
import numpy as np
import scipy.sparse
# pylint: disable=g-bad-import-order
import data_utils
Path = pathlib.Path
FLAGS = flags.FLAGS
_DATA_FILES_AND_PARAMETERS = {
'author_affiliated_with_institution_edges.npy': {
'content_names': ('author', 'institution'),
'use_boolean': False
},
'author_writes_paper_edges.npy': {
'content_names': ('author', 'paper'),
'use_boolean': False
},
'paper_cites_paper_edges.npy': {
'content_names': ('paper', 'paper'),
'use_boolean': True
},
}
flags.DEFINE_string('data_root', None, 'Data root directory')
flags.DEFINE_boolean('skip_existing', True, 'Skips existing CSR files')
flags.mark_flags_as_required(['data_root'])
def _read_edge_data(path):
try:
return np.load(path, mmap_mode='r')
except FileNotFoundError:
# If the file path can't be found by np.load, use the file handle w/o mmap.
with path.open('rb') as fid:
return np.load(fid)
def _build_coo(edges_data, use_boolean=False):
if use_boolean:
mat_coo = scipy.sparse.coo_matrix(
(np.ones_like(edges_data[1, :],
dtype=bool), (edges_data[0, :], edges_data[1, :])))
else:
mat_coo = scipy.sparse.coo_matrix(
(edges_data[1, :], (edges_data[0, :], edges_data[1, :])))
return mat_coo
def _get_output_paths(directory, content_names, use_boolean):
boolean_str = '_b' if use_boolean else ''
transpose_str = '_t' if len(set(content_names)) == 1 else ''
output_prefix = '_'.join(content_names)
output_prefix_t = '_'.join(content_names[::-1])
output_filename = f'{output_prefix}{boolean_str}.npz'
output_filename_t = f'{output_prefix_t}{boolean_str}{transpose_str}.npz'
output_path = directory / output_filename
output_path_t = directory / output_filename_t
return output_path, output_path_t
def _write_csr(path, csr):
path.parent.mkdir(parents=True, exist_ok=True)
with path.open('wb') as fid:
scipy.sparse.save_npz(fid, csr)
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
raw_data_dir = Path(FLAGS.data_root) / data_utils.RAW_DIR
preprocessed_dir = Path(FLAGS.data_root) / data_utils.PREPROCESSED_DIR
for input_filename, parameters in _DATA_FILES_AND_PARAMETERS.items():
input_path = raw_data_dir / input_filename
output_path, output_path_t = _get_output_paths(preprocessed_dir,
**parameters)
if FLAGS.skip_existing and output_path.exists() and output_path_t.exists():
# If both files exist, skip. When only one exists, that's handled below.
logging.info(
'%s and %s exist: skipping. Use flag `--skip_existing=False`'
'to force overwrite existing.', output_path, output_path_t)
continue
logging.info('Reading edge data from: %s', input_path)
edge_data = _read_edge_data(input_path)
logging.info('Building CSR matrices')
mat_coo = _build_coo(edge_data, use_boolean=parameters['use_boolean'])
# Convert matrices to CSR and write to disk.
if not FLAGS.skip_existing or not output_path.exists():
logging.info('Writing CSR matrix to: %s', output_path)
mat_csr = mat_coo.tocsr()
_write_csr(output_path, mat_csr)
del mat_csr # Free up memory asap.
else:
logging.info(
'%s exists: skipping. Use flag `--skip_existing=False`'
'to force overwrite existing.', output_path)
if not FLAGS.skip_existing or not output_path_t.exists():
logging.info('Writing (transposed) CSR matrix to: %s', output_path_t)
mat_csr_t = mat_coo.transpose().tocsr()
_write_csr(output_path_t, mat_csr_t)
del mat_csr_t # Free up memory asap.
else:
logging.info(
'%s exists: skipping. Use flag `--skip_existing=False`'
'to force overwrite existing.', output_path_t)
del mat_coo # Free up memory asap.
if __name__ == '__main__':
app.run(main)
+491
View File
@@ -0,0 +1,491 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset utilities."""
import functools
import pathlib
from typing import Dict, Tuple
from absl import logging
from graph_nets import graphs as tf_graphs
from graph_nets import utils_tf
import numpy as np
import scipy.sparse as sp
import tensorflow as tf
import tqdm
# pylint: disable=g-bad-import-order
import sub_sampler
Path = pathlib.Path
NUM_PAPERS = 121751666
NUM_AUTHORS = 122383112
NUM_INSTITUTIONS = 25721
EMBEDDING_SIZE = 768
NUM_CLASSES = 153
NUM_NODES = NUM_PAPERS + NUM_AUTHORS + NUM_INSTITUTIONS
NUM_EDGES = 1_728_364_232
assert NUM_NODES == 244_160_499
NUM_K_FOLD_SPLITS = 10
OFFSETS = {
"paper": 0,
"author": NUM_PAPERS,
"institution": NUM_PAPERS + NUM_AUTHORS,
}
SIZES = {
"paper": NUM_PAPERS,
"author": NUM_AUTHORS,
"institution": NUM_INSTITUTIONS
}
RAW_DIR = Path("raw")
PREPROCESSED_DIR = Path("preprocessed")
RAW_NODE_FEATURES_FILENAME = RAW_DIR / "node_feat.npy"
RAW_NODE_LABELS_FILENAME = RAW_DIR / "node_label.npy"
RAW_NODE_YEAR_FILENAME = RAW_DIR / "node_year.npy"
TRAIN_INDEX_FILENAME = RAW_DIR / "train_idx.npy"
VALID_INDEX_FILENAME = RAW_DIR / "train_idx.npy"
TEST_INDEX_FILENAME = RAW_DIR / "train_idx.npy"
EDGES_PAPER_PAPER_B = PREPROCESSED_DIR / "paper_paper_b.npz"
EDGES_PAPER_PAPER_B_T = PREPROCESSED_DIR / "paper_paper_b_t.npz"
EDGES_AUTHOR_INSTITUTION = PREPROCESSED_DIR / "author_institution.npz"
EDGES_INSTITUTION_AUTHOR = PREPROCESSED_DIR / "institution_author.npz"
EDGES_AUTHOR_PAPER = PREPROCESSED_DIR / "author_paper.npz"
EDGES_PAPER_AUTHOR = PREPROCESSED_DIR / "paper_author.npz"
PCA_PAPER_FEATURES_FILENAME = PREPROCESSED_DIR / "paper_feat_pca_129.npy"
PCA_AUTHOR_FEATURES_FILENAME = (
PREPROCESSED_DIR / "author_feat_from_paper_feat_pca_129.npy")
PCA_INSTITUTION_FEATURES_FILENAME = (
PREPROCESSED_DIR / "institution_feat_from_paper_feat_pca_129.npy")
PCA_MERGED_FEATURES_FILENAME = (
PREPROCESSED_DIR / "merged_feat_from_paper_feat_pca_129.npy")
NEIGHBOR_INDICES_FILENAME = PREPROCESSED_DIR / "neighbor_indices.npy"
NEIGHBOR_DISTANCES_FILENAME = PREPROCESSED_DIR / "neighbor_distances.npy"
FUSED_NODE_LABELS_FILENAME = PREPROCESSED_DIR / "fused_node_labels.npy"
FUSED_PAPER_EDGES_FILENAME = PREPROCESSED_DIR / "fused_paper_edges.npz"
FUSED_PAPER_EDGES_T_FILENAME = PREPROCESSED_DIR / "fused_paper_edges_t.npz"
K_FOLD_SPLITS_DIR = Path("k_fold_splits")
def get_raw_directory(data_root):
return Path(data_root) / "raw"
def get_preprocessed_directory(data_root):
return Path(data_root) / "preprocessed"
def _log_path_decorator(fn):
def _decorated_fn(path, **kwargs):
logging.info("Loading %s", path)
output = fn(path, **kwargs)
logging.info("Finish loading %s", path)
return output
return _decorated_fn
@_log_path_decorator
def load_csr(path, debug=False):
if debug:
# Dummy matrix for debugging.
return sp.csr_matrix(np.zeros([10, 10]))
return sp.load_npz(str(path))
@_log_path_decorator
def load_npy(path):
return np.load(str(path))
@functools.lru_cache()
def get_arrays(data_root="/data/",
use_fused_node_labels=True,
use_fused_node_adjacencies=True,
return_pca_embeddings=True,
k_fold_split_id=None,
return_adjacencies=True,
use_dummy_adjacencies=False):
"""Returns all arrays needed for training."""
logging.info("Starting to get files")
data_root = Path(data_root)
array_dict = {}
array_dict["paper_year"] = load_npy(data_root / RAW_NODE_YEAR_FILENAME)
if k_fold_split_id is None:
train_indices = load_npy(data_root / TRAIN_INDEX_FILENAME)
valid_indices = load_npy(data_root / VALID_INDEX_FILENAME)
else:
train_indices, valid_indices = get_train_and_valid_idx_for_split(
k_fold_split_id, num_splits=NUM_K_FOLD_SPLITS,
root_path=data_root / K_FOLD_SPLITS_DIR)
array_dict["train_indices"] = train_indices
array_dict["valid_indices"] = valid_indices
array_dict["test_indices"] = load_npy(data_root / TEST_INDEX_FILENAME)
if use_fused_node_labels:
array_dict["paper_label"] = load_npy(data_root / FUSED_NODE_LABELS_FILENAME)
else:
array_dict["paper_label"] = load_npy(data_root / RAW_NODE_LABELS_FILENAME)
if return_adjacencies:
logging.info("Starting to get adjacencies.")
if use_fused_node_adjacencies:
paper_paper_index = load_csr(
data_root / FUSED_PAPER_EDGES_FILENAME, debug=use_dummy_adjacencies)
paper_paper_index_t = load_csr(
data_root / FUSED_PAPER_EDGES_T_FILENAME, debug=use_dummy_adjacencies)
else:
paper_paper_index = load_csr(
data_root / EDGES_PAPER_PAPER_B, debug=use_dummy_adjacencies)
paper_paper_index_t = load_csr(
data_root / EDGES_PAPER_PAPER_B_T, debug=use_dummy_adjacencies)
array_dict.update(
dict(
author_institution_index=load_csr(
data_root / EDGES_AUTHOR_INSTITUTION,
debug=use_dummy_adjacencies),
institution_author_index=load_csr(
data_root / EDGES_INSTITUTION_AUTHOR,
debug=use_dummy_adjacencies),
author_paper_index=load_csr(
data_root / EDGES_AUTHOR_PAPER, debug=use_dummy_adjacencies),
paper_author_index=load_csr(
data_root / EDGES_PAPER_AUTHOR, debug=use_dummy_adjacencies),
paper_paper_index=paper_paper_index,
paper_paper_index_t=paper_paper_index_t,
))
if return_pca_embeddings:
array_dict["bert_pca_129"] = np.load(
data_root / PCA_MERGED_FEATURES_FILENAME, mmap_mode="r")
assert array_dict["bert_pca_129"].shape == (NUM_NODES, 129)
logging.info("Finish getting files")
# pytype: disable=attribute-error
assert array_dict["paper_year"].shape[0] == NUM_PAPERS
assert array_dict["paper_label"].shape[0] == NUM_PAPERS
if return_adjacencies and not use_dummy_adjacencies:
array_dict = _fix_adjacency_shapes(array_dict)
assert array_dict["paper_author_index"].shape == (NUM_PAPERS, NUM_AUTHORS)
assert array_dict["author_paper_index"].shape == (NUM_AUTHORS, NUM_PAPERS)
assert array_dict["paper_paper_index"].shape == (NUM_PAPERS, NUM_PAPERS)
assert array_dict["paper_paper_index_t"].shape == (NUM_PAPERS, NUM_PAPERS)
assert array_dict["institution_author_index"].shape == (
NUM_INSTITUTIONS, NUM_AUTHORS)
assert array_dict["author_institution_index"].shape == (
NUM_AUTHORS, NUM_INSTITUTIONS)
# pytype: enable=attribute-error
return array_dict
def add_nodes_year(graph, paper_year):
nodes = graph.nodes.copy()
indices = nodes["index"]
year = paper_year[np.minimum(indices, paper_year.shape[0] - 1)].copy()
year[nodes["type"] != 0] = 1900
nodes["year"] = year
return graph._replace(nodes=nodes)
def add_nodes_label(graph, paper_label):
nodes = graph.nodes.copy()
indices = nodes["index"]
label = paper_label[np.minimum(indices, paper_label.shape[0] - 1)]
label[nodes["type"] != 0] = 0
nodes["label"] = label
return graph._replace(nodes=nodes)
def add_nodes_embedding_from_array(graph, array):
"""Adds embeddings from the sstable_service for the indices."""
nodes = graph.nodes.copy()
indices = nodes["index"]
embedding_indices = indices.copy()
embedding_indices[nodes["type"] == 1] += NUM_PAPERS
embedding_indices[nodes["type"] == 2] += NUM_PAPERS + NUM_AUTHORS
# Gather the embeddings for the indices.
nodes["features"] = array[embedding_indices]
return graph._replace(nodes=nodes)
def get_graph_subsampling_dataset(
prefix, arrays, shuffle_indices, ratio_unlabeled_data_to_labeled_data,
max_nodes, max_edges,
**subsampler_kwargs):
"""Returns tf_dataset for online sampling."""
def generator():
labeled_indices = arrays[f"{prefix}_indices"]
if ratio_unlabeled_data_to_labeled_data > 0:
num_unlabeled_data_to_add = int(ratio_unlabeled_data_to_labeled_data *
labeled_indices.shape[0])
unlabeled_indices = np.random.choice(
NUM_PAPERS, size=num_unlabeled_data_to_add, replace=False)
root_node_indices = np.concatenate([labeled_indices, unlabeled_indices])
else:
root_node_indices = labeled_indices
if shuffle_indices:
root_node_indices = root_node_indices.copy()
np.random.shuffle(root_node_indices)
for index in root_node_indices:
graph = sub_sampler.subsample_graph(
index,
arrays["author_institution_index"],
arrays["institution_author_index"],
arrays["author_paper_index"],
arrays["paper_author_index"],
arrays["paper_paper_index"],
arrays["paper_paper_index_t"],
paper_years=arrays["paper_year"],
max_nodes=max_nodes,
max_edges=max_edges,
**subsampler_kwargs)
graph = add_nodes_label(graph, arrays["paper_label"])
graph = add_nodes_year(graph, arrays["paper_year"])
graph = tf_graphs.GraphsTuple(*graph)
yield graph
sample_graph = next(generator())
return tf.data.Dataset.from_generator(
generator,
output_signature=utils_tf.specs_from_graphs_tuple(sample_graph))
def paper_features_to_author_features(
author_paper_index, paper_features):
"""Averages paper features to authors."""
assert paper_features.shape[0] == NUM_PAPERS
assert author_paper_index.shape[0] == NUM_AUTHORS
author_features = np.zeros(
[NUM_AUTHORS, paper_features.shape[1]], dtype=paper_features.dtype)
for author_i in range(NUM_AUTHORS):
paper_indices = author_paper_index[author_i].indices
author_features[author_i] = paper_features[paper_indices].mean(
axis=0, dtype=np.float32)
if author_i % 10000 == 0:
logging.info("%d/%d", author_i, NUM_AUTHORS)
return author_features
def author_features_to_institution_features(
institution_author_index, author_features):
"""Averages author features to institutions."""
assert author_features.shape[0] == NUM_AUTHORS
assert institution_author_index.shape[0] == NUM_INSTITUTIONS
institution_features = np.zeros(
[NUM_INSTITUTIONS, author_features.shape[1]], dtype=author_features.dtype)
for institution_i in range(NUM_INSTITUTIONS):
author_indices = institution_author_index[institution_i].indices
institution_features[institution_i] = author_features[
author_indices].mean(axis=0, dtype=np.float32)
if institution_i % 10000 == 0:
logging.info("%d/%d", institution_i, NUM_INSTITUTIONS)
return institution_features
def generate_fused_paper_adjacency_matrix(neighbor_indices, neighbor_distances,
paper_paper_csr):
"""Generates fused adjacency matrix for identical nodes."""
# First construct set of identical node indices.
# NOTE: Since we take only top K=26 identical pairs for each node, this is not
# actually exhaustive. Also, if A and B are equal, and B and C are equal,
# this method would not necessarily detect A and C being equal.
# However, this should capture almost all cases.
logging.info("Generating fused paper adjacency matrix")
eps = 0.0
mask = ((neighbor_indices != np.mgrid[:neighbor_indices.shape[0], :1]) &
(neighbor_distances <= eps))
identical_pairs = list(map(tuple, np.nonzero(mask)))
del mask
# Have a csc version for fast column access.
paper_paper_csc = paper_paper_csr.tocsc()
# Construct new matrix as coo, starting off with original rows/cols.
paper_paper_coo = paper_paper_csr.tocoo()
new_rows = [paper_paper_coo.row]
new_cols = [paper_paper_coo.col]
for pair in tqdm.tqdm(identical_pairs):
# STEP ONE: First merge papers being cited by the pair.
# Add edges from second paper, to all papers cited by first paper.
cited_by_first = paper_paper_csr.getrow(pair[0]).nonzero()[1]
if cited_by_first.shape[0] > 0:
new_rows.append(pair[1] * np.ones_like(cited_by_first))
new_cols.append(cited_by_first)
# Add edges from first paper, to all papers cited by second paper.
cited_by_second = paper_paper_csr.getrow(pair[1]).nonzero()[1]
if cited_by_second.shape[0] > 0:
new_rows.append(pair[0] * np.ones_like(cited_by_second))
new_cols.append(cited_by_second)
# STEP TWO: Then merge papers that cite the pair.
# Add edges to second paper, from all papers citing the first paper.
citing_first = paper_paper_csc.getcol(pair[0]).nonzero()[0]
if citing_first.shape[0] > 0:
new_rows.append(citing_first)
new_cols.append(pair[1] * np.ones_like(citing_first))
# Add edges to first paper, from all papers citing the second paper.
citing_second = paper_paper_csc.getcol(pair[1]).nonzero()[0]
if citing_second.shape[0] > 0:
new_rows.append(citing_second)
new_cols.append(pair[0] * np.ones_like(citing_second))
logging.info("Done with adjacency loop")
paper_paper_coo_shape = paper_paper_coo.shape
del paper_paper_csr
del paper_paper_csc
del paper_paper_coo
# All done; now concatenate everything together and form new matrix.
new_rows = np.concatenate(new_rows)
new_cols = np.concatenate(new_cols)
return sp.coo_matrix(
(np.ones_like(new_rows, dtype=np.bool), (new_rows, new_cols)),
shape=paper_paper_coo_shape).tocsr()
def generate_k_fold_splits(
train_idx, valid_idx, output_path, num_splits=NUM_K_FOLD_SPLITS):
"""Generates splits adding fractions of the validation split to training."""
output_path = Path(output_path)
np.random.seed(42)
valid_idx = np.random.permutation(valid_idx)
# Split into `num_parts` (almost) identically sized arrays.
valid_idx_parts = np.array_split(valid_idx, num_splits)
for i in range(num_splits):
# Add all but the i'th subpart to training set.
new_train_idx = np.concatenate(
[train_idx, *valid_idx_parts[:i], *valid_idx_parts[i+1:]])
# i'th subpart is validation set.
new_valid_idx = valid_idx_parts[i]
train_path = output_path / f"train_idx_{i}_{num_splits}.npy"
valid_path = output_path / f"valid_idx_{i}_{num_splits}.npy"
np.save(train_path, new_train_idx)
np.save(valid_path, new_valid_idx)
logging.info("Saved: %s", train_path)
logging.info("Saved: %s", valid_path)
def get_train_and_valid_idx_for_split(
split_id: int,
num_splits: int,
root_path: str,
) -> Tuple[np.ndarray, np.ndarray]:
"""Returns train and valid indices for given split."""
new_train_idx = load_npy(f"{root_path}/train_idx_{split_id}_{num_splits}.npy")
new_valid_idx = load_npy(f"{root_path}/valid_idx_{split_id}_{num_splits}.npy")
return new_train_idx, new_valid_idx
def generate_fused_node_labels(neighbor_indices, neighbor_distances,
node_labels, train_indices, valid_indices,
test_indices):
"""Generates fused adjacency matrix for identical nodes."""
logging.info("Generating fused node labels")
valid_indices = set(valid_indices.tolist())
test_indices = set(test_indices.tolist())
valid_or_test_indices = valid_indices | test_indices
train_indices = train_indices[train_indices < neighbor_indices.shape[0]]
# Go through list of all pairs where one node is in training set, and
for i in tqdm.tqdm(train_indices):
for j in range(neighbor_indices.shape[1]):
other_index = neighbor_indices[i][j]
# if the other is not a validation or test node,
if other_index in valid_or_test_indices:
continue
# and they are identical,
if neighbor_distances[i][j] == 0:
# assign the label of the training node to the other node
node_labels[other_index] = node_labels[i]
return node_labels
def _pad_to_shape(
sparse_csr_matrix: sp.csr_matrix,
output_shape: Tuple[int, int]) -> sp.csr_matrix:
"""Pads a csr sparse matrix to the given shape."""
# We should not try to expand anything smaller.
assert np.all(sparse_csr_matrix.shape <= output_shape)
# Maybe it already has the right shape.
if sparse_csr_matrix.shape == output_shape:
return sparse_csr_matrix
# Append as many indptr elements as we need to match the leading size,
# This is achieved by just padding with copies of the last indptr element.
required_padding = output_shape[0] - sparse_csr_matrix.shape[0]
updated_indptr = np.concatenate(
[sparse_csr_matrix.indptr] +
[sparse_csr_matrix.indptr[-1:]] * required_padding,
axis=0)
# The change in trailing size does not have structural implications, it just
# determines the highest possible value for the indices, so it is sufficient
# to just pass the new output shape, with the correct trailing size.
return sp.csr.csr_matrix(
(sparse_csr_matrix.data,
sparse_csr_matrix.indices,
updated_indptr),
shape=output_shape)
def _fix_adjacency_shapes(
arrays: Dict[str, sp.csr.csr_matrix],
) -> Dict[str, sp.csr.csr_matrix]:
"""Fixes the shapes of the adjacency matrices."""
arrays = arrays.copy()
for key in ["author_institution_index",
"author_paper_index",
"paper_paper_index",
"institution_author_index",
"paper_author_index",
"paper_paper_index_t"]:
type_sender = key.split("_")[0]
type_receiver = key.split("_")[1]
arrays[key] = _pad_to_shape(
arrays[key], output_shape=(SIZES[type_sender], SIZES[type_receiver]))
return arrays
+297
View File
@@ -0,0 +1,297 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MAG240M-LSC datasets."""
import threading
from typing import NamedTuple, Optional
import jax
import jraph
from ml_collections import config_dict
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
# pylint: disable=g-bad-import-order
# pytype: disable=import-error
import batching_utils
import data_utils
# We only want to load these arrays once for all threads.
# `get_arrays` uses an LRU cache which is not thread safe.
LOADING_RAW_ARRAYS_LOCK = threading.Lock()
NUM_CLASSES = data_utils.NUM_CLASSES
_MAX_DEPTH_IN_SUBGRAPH = 3
class Batch(NamedTuple):
"""NamedTuple to represent batches of data."""
graph: jraph.GraphsTuple
node_labels: np.ndarray
label_mask: np.ndarray
central_node_mask: np.ndarray
node_indices: np.ndarray
absolute_node_indices: np.ndarray
def build_dataset_iterator(
data_root: str,
split: str,
dynamic_batch_size_config: config_dict.ConfigDict,
online_subsampling_kwargs: dict, # pylint: disable=g-bare-generic
debug: bool = False,
is_training: bool = True,
k_fold_split_id: Optional[int] = None,
ratio_unlabeled_data_to_labeled_data: float = 0.0,
use_all_labels_when_not_training: bool = False,
use_dummy_adjacencies: bool = False,
):
"""Returns an iterator over Batches from the dataset."""
if split == 'test':
use_all_labels_when_not_training = True
if not is_training:
ratio_unlabeled_data_to_labeled_data = 0.0
# Load the master data arrays.
with LOADING_RAW_ARRAYS_LOCK:
array_dict = data_utils.get_arrays(
data_root, k_fold_split_id=k_fold_split_id,
use_dummy_adjacencies=use_dummy_adjacencies)
node_labels = array_dict['paper_label'].reshape(-1)
train_indices = array_dict['train_indices'].astype(np.int32)
is_train_index = np.zeros(node_labels.shape[0], dtype=np.int32)
is_train_index[train_indices] = 1
valid_indices = array_dict['valid_indices'].astype(np.int32)
is_valid_index = np.zeros(node_labels.shape[0], dtype=np.int32)
is_valid_index[valid_indices] = 1
is_train_or_valid_index = is_train_index + is_valid_index
def sstable_to_intermediate_graph(graph):
indices = tf.cast(graph.nodes['index'], tf.int32)
first_index = indices[..., 0]
# Add an additional absolute index, but adding offsets to authors, and
# institution indices.
absolute_index = graph.nodes['index']
is_author = graph.nodes['type'] == 1
absolute_index = tf.where(
is_author, absolute_index + data_utils.NUM_PAPERS, absolute_index)
is_institution = graph.nodes['type'] == 2
absolute_index = tf.where(
is_institution,
absolute_index + data_utils.NUM_PAPERS + data_utils.NUM_AUTHORS,
absolute_index)
is_same_as_central_node = tf.math.equal(indices, first_index)
input_nodes = graph.nodes
graph = graph._replace(
nodes={
'one_hot_type':
tf.one_hot(tf.cast(input_nodes['type'], tf.int32), 3),
'one_hot_depth':
tf.one_hot(
tf.cast(input_nodes['depth'], tf.int32),
_MAX_DEPTH_IN_SUBGRAPH),
'year':
tf.expand_dims(input_nodes['year'], axis=-1),
'label':
tf.one_hot(
tf.cast(input_nodes['label'], tf.int32),
NUM_CLASSES),
'is_same_as_central_node':
is_same_as_central_node,
# Only first node in graph has a valid label.
'is_central_node':
tf.one_hot(0,
tf.shape(input_nodes['label'])[0]),
'index':
input_nodes['index'],
'absolute_index': absolute_index,
},
globals=tf.expand_dims(graph.globals, axis=-1),
)
return graph
ds = data_utils.get_graph_subsampling_dataset(
split,
array_dict,
shuffle_indices=is_training,
ratio_unlabeled_data_to_labeled_data=ratio_unlabeled_data_to_labeled_data,
max_nodes=dynamic_batch_size_config.n_node - 1, # Keep space for pads.
max_edges=dynamic_batch_size_config.n_edge,
**online_subsampling_kwargs)
if debug:
ds = ds.take(50)
ds = ds.map(
sstable_to_intermediate_graph,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
if is_training:
ds = ds.shard(jax.process_count(), jax.process_index())
ds = ds.shuffle(buffer_size=1 if debug else 128)
ds = ds.repeat()
ds = ds.prefetch(1 if debug else tf.data.experimental.AUTOTUNE)
np_ds = iter(tfds.as_numpy(ds))
batched_np_ds = batching_utils.dynamically_batch(
np_ds,
**dynamic_batch_size_config,
)
def intermediate_graph_to_batch(graph):
central_node_mask = graph.nodes['is_central_node']
label = graph.nodes['label']
node_indices = graph.nodes['index']
absolute_indices = graph.nodes['absolute_index']
### Construct label as a feature for non-central nodes.
# First do a lookup with node indices, with a np.minimum to ensure we do not
# index out of bounds due to num_authors being larger than num_papers.
is_same_as_central_node = graph.nodes['is_same_as_central_node']
capped_indices = np.minimum(node_indices, node_labels.shape[0] - 1)
label_as_feature = node_labels[capped_indices]
# Nodes which are not in train set should get `num_classes` label.
# Nodes in test set or non-arXiv nodes have -1 or nan labels.
# Mask out invalid labels and non-papers.
use_label_as_feature = np.logical_and(label_as_feature >= 0,
graph.nodes['one_hot_type'][..., 0])
if split == 'train' or not use_all_labels_when_not_training:
# Mask out validation papers and non-arxiv papers who
# got labels from fusing with arxiv papers.
use_label_as_feature = np.logical_and(is_train_index[capped_indices],
use_label_as_feature)
label_as_feature = np.where(use_label_as_feature, label_as_feature,
NUM_CLASSES)
# Mask out central node label in case it appears again.
label_as_feature = np.where(is_same_as_central_node, NUM_CLASSES,
label_as_feature)
# Nodes which are not papers get `NUM_CLASSES+1` label.
label_as_feature = np.where(graph.nodes['one_hot_type'][..., 0],
label_as_feature, NUM_CLASSES+1)
nodes = {
'label_as_feature': label_as_feature,
'year': graph.nodes['year'],
'bitstring_year': _get_bitstring_year_representation(
graph.nodes['year']),
'one_hot_type': graph.nodes['one_hot_type'],
'one_hot_depth': graph.nodes['one_hot_depth'],
}
graph = graph._replace(
nodes=nodes,
globals={},
)
is_train_or_valid_node = np.logical_and(
is_train_or_valid_index[capped_indices],
graph.nodes['one_hot_type'][..., 0])
if is_training:
label_mask = np.logical_and(central_node_mask, is_train_or_valid_node)
else:
# `label_mask` is used to index into valid central nodes by prediction
# calculator. Since that computation is only done when not training, and
# at that time we are guaranteed all central nodes have valid labels,
# we just set label_mask = central_node_mask when not training.
label_mask = central_node_mask
batch = Batch(
graph=graph,
node_labels=label,
central_node_mask=central_node_mask,
label_mask=label_mask,
node_indices=node_indices,
absolute_node_indices=absolute_indices)
# Transform integers into one-hots.
batch = _add_one_hot_features_to_batch(batch)
# Gather PCA features.
return _add_embeddings_to_batch(batch, array_dict['bert_pca_129'])
batch_list = []
for batch in batched_np_ds:
with jax.profiler.StepTraceAnnotation('batch_postprocessing'):
batch = intermediate_graph_to_batch(batch)
if is_training:
batch_list.append(batch)
if len(batch_list) == jax.local_device_count():
yield jax.device_put_sharded(batch_list, jax.local_devices())
batch_list = []
else:
yield batch
def _get_bitstring_year_representation(year: np.ndarray):
"""Return year as bitstring."""
min_year = 1900
max_training_year = 2018
offseted_year = np.minimum(year, max_training_year) - min_year
return np.unpackbits(offseted_year.astype(np.uint8), axis=-1)
def _np_one_hot(targets: np.ndarray, nb_classes: int):
res = np.zeros(targets.shape + (nb_classes,), dtype=np.float16)
np.put_along_axis(res, targets.astype(np.int32)[..., None], 1.0, axis=-1)
return res
def _get_one_hot_year_representation(
year: np.ndarray,
one_hot_type: np.ndarray,
):
"""Returns good representation for year."""
# Bucket edges found based on quantiles to bucket into 20 equal sized buckets.
bucket_edges = np.array([
1964, 1975, 1983, 1989, 1994, 1998, 2001, 2004,
2006, 2008, 2009, 2011, 2012, 2013, 2014, 2016,
2017, # 2018, 2019, 2020 contain last-year-of-train, eval, test nodes
])
year = np.squeeze(year, axis=-1)
year_id = np.searchsorted(bucket_edges, year)
is_paper = one_hot_type[..., 0]
bucket_id_for_non_paper = len(bucket_edges) + 1
bucket_id = np.where(is_paper, year_id, bucket_id_for_non_paper)
one_hot_year = _np_one_hot(bucket_id, len(bucket_edges) + 2)
return one_hot_year
def _add_one_hot_features_to_batch(batch: Batch) -> Batch:
"""Transforms integer features into one-hot features."""
nodes = batch.graph.nodes.copy()
nodes['one_hot_year'] = _get_one_hot_year_representation(
nodes['year'], nodes['one_hot_type'])
del nodes['year']
# NUM_CLASSES plus one category for papers for which a class is not provided
# and another for nodes that are not papers.
nodes['one_hot_label_as_feature'] = _np_one_hot(
nodes['label_as_feature'], NUM_CLASSES + 2)
del nodes['label_as_feature']
return batch._replace(graph=batch.graph._replace(nodes=nodes))
def _add_embeddings_to_batch(batch: Batch, embeddings: np.ndarray) -> Batch:
nodes = batch.graph.nodes.copy()
nodes['features'] = embeddings[batch.absolute_node_indices]
graph = batch.graph._replace(nodes=nodes)
return batch._replace(graph=graph)
+110
View File
@@ -0,0 +1,110 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Download data required for training and evaluating models."""
import pathlib
from absl import app
from absl import flags
from absl import logging
from google.cloud import storage
# pylint: disable=g-bad-import-order
import data_utils
Path = pathlib.Path
_BUCKET_NAME = 'deepmind-ogb-lsc'
_MAX_DOWNLOAD_ATTEMPTS = 5
FLAGS = flags.FLAGS
flags.DEFINE_enum('payload', None, ['data', 'models'],
'Download "data" or "models"?')
flags.DEFINE_string('task_root', None, 'Local task root directory')
DATA_RELATIVE_PATHS = (
data_utils.RAW_NODE_YEAR_FILENAME,
data_utils.TRAIN_INDEX_FILENAME,
data_utils.VALID_INDEX_FILENAME,
data_utils.TEST_INDEX_FILENAME,
data_utils.K_FOLD_SPLITS_DIR,
data_utils.FUSED_NODE_LABELS_FILENAME,
data_utils.FUSED_PAPER_EDGES_FILENAME,
data_utils.FUSED_PAPER_EDGES_T_FILENAME,
data_utils.EDGES_AUTHOR_INSTITUTION,
data_utils.EDGES_INSTITUTION_AUTHOR,
data_utils.EDGES_AUTHOR_PAPER,
data_utils.EDGES_PAPER_AUTHOR,
data_utils.PCA_MERGED_FEATURES_FILENAME,
)
class DataCorruptionError(Exception):
pass
def _get_gcs_root():
return Path('mag') / FLAGS.payload
def _get_gcs_bucket():
storage_client = storage.Client.create_anonymous_client()
return storage_client.bucket(_BUCKET_NAME)
def _write_blob_to_destination(blob, task_root, ignore_existing=True):
"""Write the blob."""
logging.info("Copying blob: '%s'", blob.name)
destination_path = Path(task_root) / Path(*Path(blob.name).parts[1:])
logging.info(" ... to: '%s'", str(destination_path))
if ignore_existing and destination_path.exists():
return
destination_path.parent.mkdir(parents=True, exist_ok=True)
checksum = 'crc32c'
for attempt in range(_MAX_DOWNLOAD_ATTEMPTS):
try:
blob.download_to_filename(destination_path.as_posix(), checksum=checksum)
except storage.client.resumable_media.common.DataCorruption:
pass
else:
break
else:
raise DataCorruptionError(f"Checksum ('{checksum}') for {blob.name} failed "
f'after {attempt + 1} attempts')
def main(unused_argv):
bucket = _get_gcs_bucket()
if FLAGS.payload == 'data':
relative_paths = DATA_RELATIVE_PATHS
else:
relative_paths = (None,)
for relative_path in relative_paths:
if relative_path is None:
relative_path = str(_get_gcs_root())
else:
relative_path = str(_get_gcs_root() / relative_path)
logging.info("Copying relative path: '%s'", relative_path)
blobs = bucket.list_blobs(prefix=relative_path)
for blob in blobs:
_write_blob_to_destination(blob, FLAGS.task_root)
if __name__ == '__main__':
flags.mark_flag_as_required('payload')
flags.mark_flag_as_required('task_root')
app.run(main)
+227
View File
@@ -0,0 +1,227 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ensemble k-fold predictions and generate final submission file."""
import collections
import os
from absl import app
from absl import flags
from absl import logging
import dill
import jax
import numpy as np
from ogb import lsc
# pylint: disable=g-bad-import-order
import data_utils
import losses
_NUM_KFOLD_SPLITS = 10
FLAGS = flags.FLAGS
_DATA_ROOT = flags.DEFINE_string('data_root', None, 'Path to the data root')
_SPLIT = flags.DEFINE_enum('split', None, ['valid', 'test'], 'Data split')
_PREDICTIONS_PATH = flags.DEFINE_string(
'predictions_path', None, 'Path with the output of the k-fold models.')
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Output path.')
def _np_one_hot(targets: np.ndarray, nb_classes: int):
res = np.zeros(targets.shape + (nb_classes,), dtype=np.float32)
np.put_along_axis(res, targets.astype(np.int32)[..., None], 1.0, axis=-1)
return res
def ensemble_predictions(
node_idx_to_logits_list,
all_labels,
node_indices,
use_mode_break_tie_by_mean: bool = True,
):
"""Ensemble together predictions for each node and generate final predictions."""
# First, assert that each node has the same number of predictions to ensemble.
num_predictions_per_node = [
len(x) for x in node_idx_to_logits_list.values()
]
num_models = np.unique(num_predictions_per_node)
assert num_models.shape[0] == 1
num_models = num_models[0]
# Gather all logits, shape should be [num_nodes, num_models, num_classes].
all_logits = np.stack(
[np.stack(node_idx_to_logits_list[idx]) for idx in node_indices])
assert all_logits.shape == (node_indices.shape[0], num_models,
data_utils.NUM_CLASSES)
# Softmax on the final axis.
all_probs = jax.nn.softmax(all_logits, axis=-1)
# Take average across models axis to get probabilities.
mean_probs = np.mean(all_probs, axis=1)
# Assert there are no 2 equal logits for different classes.
max_logit_value = np.max(all_logits, axis=-1)
num_classes_with_max_value = (
all_logits == max_logit_value[..., None]).sum(axis=-1)
num_logit_ties = (num_classes_with_max_value > 1).sum()
if num_logit_ties:
logging.warn(
'Found %d models with the exact same logits for two of the classes. '
'`argmax` will choose the first.', num_logit_ties)
# Each model votes on one class per type.
all_votes = np.argmax(all_logits, axis=-1)
assert all_votes.shape == (node_indices.shape[0], num_models)
all_votes_one_hot = _np_one_hot(all_votes, data_utils.NUM_CLASSES)
assert all_votes_one_hot.shape == (node_indices.shape[0], num_models,
data_utils.NUM_CLASSES)
num_votes_per_class = np.sum(all_votes_one_hot, axis=1)
assert num_votes_per_class.shape == (
node_indices.shape[0], data_utils.NUM_CLASSES)
if use_mode_break_tie_by_mean:
# Slight hack, give high weight to votes (any number > 1 works really)
# and add probabilities between [0, 1] per class to tie-break only within
# classes with equal votes.
total_score = 10 * num_votes_per_class + mean_probs
else:
# Just take mean.
total_score = mean_probs
ensembled_logits = np.log(total_score)
return losses.Predictions(
node_indices=node_indices,
labels=all_labels,
logits=ensembled_logits,
predictions=np.argmax(ensembled_logits, axis=-1),
)
def load_predictions(predictions_path, split):
"""Loads set of predictions made by given XID."""
# Generate list of predictions per node.
# Note for validation each validation index is only present in exactly 1
# model of the k-fold, however for test it is present in all of them.
node_idx_to_logits_list = collections.defaultdict(list)
# For the 10 models in the ensemble.
for i in range(_NUM_KFOLD_SPLITS):
path = os.path.join(predictions_path, str(i))
# Find subdirectories.
# Directories will be something like:
# os.path.join(path, "step_104899_2021-06-14T18:20:05", "(test|valid).dill")
# So we make sure there is only one.
candidates = []
for date_str in os.listdir(path):
candidate_path = os.path.join(path, date_str, f'{split}.dill')
if os.path.exists(candidate_path):
candidates.append(candidate_path)
if not candidates:
raise ValueError(f'No {split} predictions found at {path}')
elif len(candidates) > 1:
raise ValueError(f'Found more than one {split} predictions: {candidates}')
path_for_kth_model_predictions = candidates[0]
with open(path_for_kth_model_predictions, 'rb') as f:
results = dill.load(f)
logging.info('Loaded %s', path_for_kth_model_predictions)
for (node_idx, logits) in zip(results.node_indices,
results.logits):
node_idx_to_logits_list[node_idx].append(logits)
return node_idx_to_logits_list
def generate_ensembled_predictions(
data_root: str, predictions_path: str, split: str) -> losses.Predictions:
"""Ensemble checkpoints from all WIDs in XID and generates submission file."""
array_dict = data_utils.get_arrays(
data_root=data_root,
return_pca_embeddings=False,
return_adjacencies=False)
# Load all valid and test predictions.
node_idx_to_logits_list = load_predictions(predictions_path, split)
# Assert that the indices loaded are as expected.
expected_idx = array_dict[f'{split}_indices']
idx_found = np.array(list(node_idx_to_logits_list.keys()))
assert np.all(np.sort(idx_found) == expected_idx)
if split == 'valid':
true_labels = array_dict['paper_label'][expected_idx.astype(np.int32)]
else:
# Don't know the test labels.
true_labels = np.full(expected_idx.shape, np.nan)
# Ensemble together all predictions.
return ensemble_predictions(
node_idx_to_logits_list, true_labels, expected_idx)
def evaluate_validation(valid_predictions):
evaluator = lsc.MAG240MEvaluator()
evaluator_ouput = evaluator.eval(
dict(y_pred=valid_predictions.predictions.astype(np.float64),
y_true=valid_predictions.labels))
logging.info(
'Validation accuracy as reported by MAG240MEvaluator: %s',
evaluator_ouput)
def save_test_submission_file(test_predictions, output_dir):
evaluator = lsc.MAG240MEvaluator()
evaluator.save_test_submission(
dict(y_pred=test_predictions.predictions.astype(np.float64)), output_dir)
logging.info('Test submission file generated at %s', output_dir)
def main(argv):
del argv
split = _SPLIT.value
ensembled_predictions = generate_ensembled_predictions(
data_root=_DATA_ROOT.value,
predictions_path=_PREDICTIONS_PATH.value,
split=split)
output_dir = _OUTPUT_PATH.value
os.makedirs(output_dir, exist_ok=True)
if split == 'valid':
evaluate_validation(ensembled_predictions)
elif split == 'test':
save_test_submission_file(ensembled_predictions, output_dir)
ensembled_predictions_path = os.path.join(output_dir, f'{split}.dill')
assert not os.path.exists(ensembled_predictions_path)
with open(ensembled_predictions_path, 'wb') as f:
dill.dump(ensembled_predictions, f)
logging.info(
'%s predictions stored at %s', split, ensembled_predictions_path)
if __name__ == '__main__':
app.run(main)
File diff suppressed because it is too large Load Diff
+51
View File
@@ -0,0 +1,51 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates the k-fold validation splits."""
import os
from absl import app
from absl import flags
import data_utils
_DATA_ROOT = flags.DEFINE_string(
'data_root', None, required=True,
help='Path containing the downloaded data.')
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir', None, required=True,
help='Output directory to write the splits to')
def main(argv):
del argv
array_dict = data_utils.get_arrays(
data_root=_DATA_ROOT.value,
return_pca_embeddings=False,
return_adjacencies=False)
os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
data_utils.generate_k_fold_splits(
train_idx=array_dict['train_indices'],
valid_idx=array_dict['valid_indices'],
output_path=_OUTPUT_DIR.value,
num_splits=data_utils.NUM_K_FOLD_SPLITS)
if __name__ == '__main__':
app.run(main)
+199
View File
@@ -0,0 +1,199 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Losses and related utilities."""
from typing import Mapping, Tuple, Sequence, NamedTuple, Dict, Optional
import jax
import jax.numpy as jnp
import jraph
import numpy as np
# pylint: disable=g-bad-import-order
import datasets
LogsDict = Mapping[str, jnp.ndarray]
class Predictions(NamedTuple):
node_indices: np.ndarray
labels: np.ndarray
predictions: np.ndarray
logits: np.ndarray
def node_classification_loss(
logits: jnp.ndarray,
batch: datasets.Batch,
extra_stats: bool = False,
) -> Tuple[jnp.ndarray, LogsDict]:
"""Gets node-wise classification loss and statistics."""
log_probs = jax.nn.log_softmax(logits)
loss = -jnp.sum(log_probs * batch.node_labels, axis=-1)
num_valid = jnp.sum(batch.label_mask)
labels = jnp.argmax(batch.node_labels, axis=-1)
is_correct = (jnp.argmax(log_probs, axis=-1) == labels)
num_correct = jnp.sum(is_correct * batch.label_mask)
loss = jnp.sum(loss * batch.label_mask) / (num_valid + 1e-8)
accuracy = num_correct / (num_valid + 1e-8)
entropy = -jnp.mean(jnp.sum(jax.nn.softmax(logits) * log_probs, axis=-1))
stats = {
'classification_loss': loss,
'prediction_entropy': entropy,
'accuracy': accuracy,
'num_valid': num_valid,
'num_correct': num_correct,
}
if extra_stats:
for k in range(1, 6):
stats[f'top_{k}_correct'] = topk_correct(logits, labels,
batch.label_mask, k)
return loss, stats
def get_predictions_labels_and_logits(
logits: jnp.ndarray,
batch: datasets.Batch,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Gets prediction labels and logits."""
mask = batch.label_mask > 0.
indices = batch.node_indices[mask]
logits = logits[mask]
predictions = jnp.argmax(logits, axis=-1)
labels = jnp.argmax(batch.node_labels[mask], axis=-1)
return indices, predictions, labels, logits
def topk_correct(
logits: jnp.ndarray,
labels: jnp.ndarray,
valid_mask: jnp.ndarray,
topk: int,
) -> jnp.ndarray:
"""Calculates top-k accuracy."""
pred_ranking = jnp.argsort(logits, axis=1)[:, ::-1]
pred_ranking = pred_ranking[:, :topk]
is_correct = jnp.any(pred_ranking == labels[:, jnp.newaxis], axis=1)
return (is_correct * valid_mask).sum()
def ensemble_predictions_by_probability_average(
predictions_list: Sequence[Predictions]) -> Predictions:
"""Ensemble predictions by ensembling the probabilities."""
_assert_consistent_predictions(predictions_list)
all_probs = np.stack([
jax.nn.softmax(predictions.logits, axis=-1)
for predictions in predictions_list
],
axis=0)
ensembled_logits = np.log(all_probs.mean(0))
return predictions_list[0]._replace(
logits=ensembled_logits, predictions=np.argmax(ensembled_logits, axis=-1))
def get_accuracy_dict(predictions: Predictions) -> Dict[str, float]:
"""Returns the accuracy dict."""
output_dict = {}
output_dict['num_valid'] = predictions.predictions.shape[0]
matches = (predictions.labels == predictions.predictions)
output_dict['accuracy'] = matches.mean()
pred_ranking = jnp.argsort(predictions.logits, axis=1)[:, ::-1]
for k in range(1, 6):
matches = jnp.any(
pred_ranking[:, :k] == predictions.labels[:, None], axis=1)
output_dict[f'top_{k}_correct'] = matches.mean()
return output_dict
def bgrl_loss(
first_online_predictions: jnp.ndarray,
second_target_projections: jnp.ndarray,
second_online_predictions: jnp.ndarray,
first_target_projections: jnp.ndarray,
symmetrize: bool,
valid_mask: jnp.ndarray,
) -> Tuple[jnp.ndarray, LogsDict]:
"""Implements BGRL loss."""
first_side_node_loss = jnp.sum(
jnp.square(
_l2_normalize(first_online_predictions, axis=-1) -
_l2_normalize(second_target_projections, axis=-1)),
axis=-1)
if symmetrize:
second_side_node_loss = jnp.sum(
jnp.square(
_l2_normalize(second_online_predictions, axis=-1) -
_l2_normalize(first_target_projections, axis=-1)),
axis=-1)
node_loss = first_side_node_loss + second_side_node_loss
else:
node_loss = first_side_node_loss
loss = (node_loss * valid_mask).sum() / (valid_mask.sum() + 1e-6)
return loss, dict(bgrl_loss=loss)
def get_corrupted_view(
graph: jraph.GraphsTuple,
feature_drop_prob: float,
edge_drop_prob: float,
rng_key: jnp.ndarray,
) -> jraph.GraphsTuple:
"""Returns corrupted graph view."""
node_key, edge_key = jax.random.split(rng_key)
def mask_feature(x):
mask = jax.random.bernoulli(node_key, 1 - feature_drop_prob, x.shape)
return x * mask
# Randomly mask features with fixed probability.
nodes = jax.tree_map(mask_feature, graph.nodes)
# Simulate dropping of edges by changing genuine edges to self-loops on
# the padded node.
num_edges = graph.senders.shape[0]
last_node_idx = graph.n_node.sum() - 1
edge_mask = jax.random.bernoulli(edge_key, 1 - edge_drop_prob, [num_edges])
senders = jnp.where(edge_mask, graph.senders, last_node_idx)
receivers = jnp.where(edge_mask, graph.receivers, last_node_idx)
# Note that n_edge will now be invalid since edges in the middle of the list
# will correspond to the final graph. Set n_edge to None to ensure we do not
# accidentally use this.
return graph._replace(
nodes=nodes,
senders=senders,
receivers=receivers,
n_edge=None,
)
def _assert_consistent_predictions(predictions_list: Sequence[Predictions]):
first_predictions = predictions_list[0]
for predictions in predictions_list:
assert np.all(predictions.node_indices == first_predictions.node_indices)
assert np.all(predictions.labels == first_predictions.labels)
assert np.all(
predictions.predictions == np.argmax(predictions.logits, axis=-1))
def _l2_normalize(
x: jnp.ndarray,
axis: Optional[int] = None,
epsilon: float = 1e-6,
) -> jnp.ndarray:
return x * jax.lax.rsqrt(
jnp.sum(jnp.square(x), axis=axis, keepdims=True) + epsilon)
+270
View File
@@ -0,0 +1,270 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MAG240M-LSC models."""
from typing import Callable, NamedTuple, Sequence
import haiku as hk
import jax
import jax.numpy as jnp
import jraph
_REDUCER_NAMES = {
'sum':
jax.ops.segment_sum,
'mean':
jraph.segment_mean,
'softmax':
jraph.segment_softmax,
}
class ModelOutput(NamedTuple):
node_embeddings: jnp.ndarray
node_embedding_projections: jnp.ndarray
node_projection_predictions: jnp.ndarray
node_logits: jnp.ndarray
def build_update_fn(
name: str,
output_sizes: Sequence[int],
activation: Callable[[jnp.ndarray], jnp.ndarray],
normalization_type: str,
is_training: bool,
):
"""Builds update function."""
def single_mlp(inner_name: str):
"""Creates a single MLP performing the update."""
mlp = hk.nets.MLP(
output_sizes=output_sizes,
name=inner_name,
activation=activation)
mlp = jraph.concatenated_args(mlp)
if normalization_type == 'layer_norm':
norm = hk.LayerNorm(
axis=-1,
create_scale=True,
create_offset=True,
name=name + '_layer_norm')
elif normalization_type == 'batch_norm':
batch_norm = hk.BatchNorm(
create_scale=True,
create_offset=True,
decay_rate=0.9,
name=f'{inner_name}_batch_norm',
cross_replica_axis=None if hk.running_init() else 'i',
)
norm = lambda x: batch_norm(x, is_training)
elif normalization_type == 'none':
return mlp
else:
raise ValueError(f'Unknown normalization type {normalization_type}')
return jraph.concatenated_args(hk.Sequential([mlp, norm]))
return single_mlp(f'{name}_homogeneous')
def build_gn(
output_sizes: Sequence[int],
activation: Callable[[jnp.ndarray], jnp.ndarray],
suffix: str,
use_sent_edges: bool,
is_training: bool,
dropedge_rate: float,
normalization_type: str,
aggregation_function: str,
):
"""Builds an InteractionNetwork with MLP update functions."""
node_update_fn = build_update_fn(
f'node_processor_{suffix}',
output_sizes,
activation=activation,
normalization_type=normalization_type,
is_training=is_training,
)
edge_update_fn = build_update_fn(
f'edge_processor_{suffix}',
output_sizes,
activation=activation,
normalization_type=normalization_type,
is_training=is_training,
)
def maybe_dropedge(x):
"""Dropout on edge messages."""
if not is_training:
return x
return x * hk.dropout(
hk.next_rng_key(),
dropedge_rate,
jnp.ones([x.shape[0], 1]),
)
dropped_edge_update_fn = lambda *args: maybe_dropedge(edge_update_fn(*args))
return jraph.InteractionNetwork(
update_edge_fn=dropped_edge_update_fn,
update_node_fn=node_update_fn,
aggregate_edges_for_nodes_fn=_REDUCER_NAMES[aggregation_function],
include_sent_messages_in_node_update=use_sent_edges,
)
def _get_activation_fn(name: str) -> Callable[[jnp.ndarray], jnp.ndarray]:
if name == 'identity':
return lambda x: x
if hasattr(jax.nn, name):
return getattr(jax.nn, name)
raise ValueError('Unknown activation function %s specified. '
'See https://jax.readthedocs.io/en/latest/jax.nn.html'
'for the list of supported function names.')
class NodePropertyEncodeProcessDecode(hk.Module):
"""Node Property Prediction Encode Process Decode Model."""
def __init__(
self,
mlp_hidden_sizes: Sequence[int],
latent_size: int,
num_classes: int,
num_message_passing_steps: int = 2,
activation: str = 'relu',
dropout_rate: float = 0.0,
dropedge_rate: float = 0.0,
use_sent_edges: bool = False,
disable_edge_updates: bool = False,
normalization_type: str = 'layer_norm',
aggregation_function: str = 'sum',
name='NodePropertyEncodeProcessDecode',
):
super().__init__(name=name)
self._num_classes = num_classes
self._latent_size = latent_size
self._output_sizes = list(mlp_hidden_sizes) + [latent_size]
self._num_message_passing_steps = num_message_passing_steps
self._activation = _get_activation_fn(activation)
self._dropout_rate = dropout_rate
self._dropedge_rate = dropedge_rate
self._use_sent_edges = use_sent_edges
self._disable_edge_updates = disable_edge_updates
self._normalization_type = normalization_type
self._aggregation_function = aggregation_function
def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
node_key, edge_key = hk.next_rng_keys(2)
nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
edges = graph.edges
if not self._disable_edge_updates:
edges = hk.dropout(edge_key, self._dropout_rate, edges)
return graph._replace(nodes=nodes, edges=edges)
def _encode(
self,
graph: jraph.GraphsTuple,
is_training: bool,
) -> jraph.GraphsTuple:
node_embed_fn = build_update_fn(
'node_encoder',
self._output_sizes,
activation=self._activation,
normalization_type=self._normalization_type,
is_training=is_training,
)
edge_embed_fn = build_update_fn(
'edge_encoder',
self._output_sizes,
activation=self._activation,
normalization_type=self._normalization_type,
is_training=is_training,
)
gn = jraph.GraphMapFeatures(edge_embed_fn, node_embed_fn)
graph = gn(graph)
if is_training:
graph = self._dropout_graph(graph)
return graph
def _process(
self,
graph: jraph.GraphsTuple,
is_training: bool,
) -> jraph.GraphsTuple:
for idx in range(self._num_message_passing_steps):
net = build_gn(
output_sizes=self._output_sizes,
activation=self._activation,
suffix=str(idx),
use_sent_edges=self._use_sent_edges,
is_training=is_training,
dropedge_rate=self._dropedge_rate,
normalization_type=self._normalization_type,
aggregation_function=self._aggregation_function)
residual_graph = net(graph)
graph = graph._replace(nodes=graph.nodes + residual_graph.nodes)
if not self._disable_edge_updates:
graph = graph._replace(edges=graph.edges + residual_graph.edges)
if is_training:
graph = self._dropout_graph(graph)
return graph
def _node_mlp(
self,
graph: jraph.GraphsTuple,
is_training: bool,
output_size: int,
name: str,
) -> jnp.ndarray:
decoder_sizes = list(self._output_sizes[:-1]) + [output_size]
net = build_update_fn(
name,
decoder_sizes,
self._activation,
normalization_type=self._normalization_type,
is_training=is_training,
)
return net(graph.nodes)
def __call__(
self,
graph: jraph.GraphsTuple,
is_training: bool,
stop_gradient_embedding_to_logits: bool = False,
) -> ModelOutput:
# Note that these update configs may need to change if
# we switch back to GraphNetwork rather than InteractionNetwork.
graph = self._encode(graph, is_training)
graph = self._process(graph, is_training)
node_embeddings = graph.nodes
node_projections = self._node_mlp(graph, is_training, self._latent_size,
'projector')
node_predictions = self._node_mlp(
graph._replace(nodes=node_projections),
is_training,
self._latent_size,
'predictor',
)
if stop_gradient_embedding_to_logits:
graph = jax.tree_map(jax.lax.stop_gradient, graph)
node_logits = self._node_mlp(graph, is_training, self._num_classes,
'logits_decoder')
return ModelOutput(
node_embeddings=node_embeddings,
node_logits=node_logits,
node_embedding_projections=node_projections,
node_projection_predictions=node_predictions,
)
+175
View File
@@ -0,0 +1,175 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Find neighborhoods around paper feature embeddings."""
import pathlib
from absl import app
from absl import flags
from absl import logging
import annoy
import numpy as np
import scipy.sparse as sp
# pylint: disable=g-bad-import-order
import data_utils
Path = pathlib.Path
_PAPER_PAPER_B_PATH = 'ogb_mag_adjacencies/paper_paper_b.npz'
FLAGS = flags.FLAGS
flags.DEFINE_string('data_root', None, 'Data root directory')
def _read_paper_pca_features():
data_root = Path(FLAGS.data_root)
path = data_root / data_utils.PCA_PAPER_FEATURES_FILENAME
with open(path, 'rb') as fid:
return np.load(fid)
def _read_adjacency_indices():
# Get adjacencies.
return data_utils.get_arrays(
data_root=FLAGS.data_root,
use_fused_node_labels=False,
use_fused_node_adjacencies=False,
return_pca_embeddings=False,
)
def build_annoy_index(features):
"""Build the Annoy index."""
logging.info('Building annoy index')
num_vectors, vector_size = features.shape
annoy_index = annoy.AnnoyIndex(vector_size, 'euclidean')
for i, x in enumerate(features):
annoy_index.add_item(i, x)
if i % 1000000 == 0:
logging.info('Adding: %d / %d (%.3g %%)', i, num_vectors,
100 * i / num_vectors)
n_trees = 10
_ = annoy_index.build(n_trees)
return annoy_index
def _get_annoy_index_path():
return Path(FLAGS.data_root) / data_utils.PREPROCESSED_DIR / 'annoy_index.ann'
def save_annoy_index(annoy_index):
logging.info('Saving annoy index')
index_path = _get_annoy_index_path()
index_path.parent.mkdir(parents=True, exist_ok=True)
annoy_index.save(str(index_path))
def read_annoy_index(features):
index_path = _get_annoy_index_path()
vector_size = features.shape[1]
annoy_index = annoy.AnnoyIndex(vector_size, 'euclidean')
annoy_index.load(str(index_path))
return annoy_index
def compute_neighbor_indices_and_distances(features):
"""Use the pre-built Annoy index to compute neighbor indices and distances."""
logging.info('Computing neighbors and distances')
annoy_index = read_annoy_index(features)
num_vectors = features.shape[0]
k = 20
pad_k = 5
search_k = -1
neighbor_indices = np.zeros([num_vectors, k + pad_k + 1], dtype=np.int32)
neighbor_distances = np.zeros([num_vectors, k + pad_k + 1], dtype=np.float32)
for i in range(num_vectors):
neighbor_indices[i], neighbor_distances[i] = annoy_index.get_nns_by_item(
i, k + pad_k + 1, search_k=search_k, include_distances=True)
if i % 10000 == 0:
logging.info('Finding neighbors %d / %d', i, num_vectors)
return neighbor_indices, neighbor_distances
def _write_neighbors(neighbor_indices, neighbor_distances):
"""Write neighbor indices and distances."""
logging.info('Writing neighbors')
indices_path = Path(FLAGS.data_root) / data_utils.NEIGHBOR_INDICES_FILENAME
distances_path = (
Path(FLAGS.data_root) / data_utils.NEIGHBOR_DISTANCES_FILENAME)
indices_path.parent.mkdir(parents=True, exist_ok=True)
distances_path.parent.mkdir(parents=True, exist_ok=True)
with open(indices_path, 'wb') as fid:
np.save(fid, neighbor_indices)
with open(distances_path, 'wb') as fid:
np.save(fid, neighbor_distances)
def _write_fused_edges(fused_paper_adjacency_matrix):
"""Write fused edges."""
data_root = Path(FLAGS.data_root)
edges_path = data_root / data_utils.FUSED_PAPER_EDGES_FILENAME
edges_t_path = data_root / data_utils.FUSED_PAPER_EDGES_T_FILENAME
edges_path.parent.mkdir(parents=True, exist_ok=True)
edges_t_path.parent.mkdir(parents=True, exist_ok=True)
with open(edges_path, 'wb') as fid:
sp.save_npz(fid, fused_paper_adjacency_matrix)
with open(edges_t_path, 'wb') as fid:
sp.save_npz(fid, fused_paper_adjacency_matrix.T)
def _write_fused_nodes(fused_node_labels):
"""Write fused nodes."""
labels_path = Path(FLAGS.data_root) / data_utils.FUSED_NODE_LABELS_FILENAME
labels_path.parent.mkdir(parents=True, exist_ok=True)
with open(labels_path, 'wb') as fid:
np.save(fid, fused_node_labels)
def main(unused_argv):
paper_pca_features = _read_paper_pca_features()
# Find neighbors.
annoy_index = build_annoy_index(paper_pca_features)
save_annoy_index(annoy_index)
neighbor_indices, neighbor_distances = compute_neighbor_indices_and_distances(
paper_pca_features)
del paper_pca_features
_write_neighbors(neighbor_indices, neighbor_distances)
data = _read_adjacency_indices()
paper_paper_csr = data['paper_paper_index']
paper_label = data['paper_label']
train_indices = data['train_indices']
valid_indices = data['valid_indices']
test_indices = data['test_indices']
del data
fused_paper_adjacency_matrix = data_utils.generate_fused_paper_adjacency_matrix(
neighbor_indices, neighbor_distances, paper_paper_csr)
_write_fused_edges(fused_paper_adjacency_matrix)
del fused_paper_adjacency_matrix
del paper_paper_csr
fused_node_labels = data_utils.generate_fused_node_labels(
neighbor_indices, neighbor_distances, paper_label, train_indices,
valid_indices, test_indices)
_write_fused_nodes(fused_node_labels)
if __name__ == '__main__':
flags.mark_flag_as_required('data_root')
app.run(main)
+67
View File
@@ -0,0 +1,67 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
set -e
while getopts ":i:o:" opt; do
case ${opt} in
i )
INPUT_DIR=$OPTARG
;;
o )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: organize_data.sh -i <Downloaded data dir> -o <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
if [[ -z "${INPUT_DIR}" ]]; then
echo "Need INPUT_DIR argument (-i <INPUT_DIR>)"
exit 1
fi
if [[ -z "${TASK_ROOT}" ]]; then
echo "Need TASK_ROOT argument (-o <TASK_ROOT>)"
exit 1
fi
DATA_ROOT="${TASK_ROOT}"/data
# Create raw directory to move all files to it.
mkdir "${INPUT_DIR}"/mag240m_kddcup2021/raw
mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_feat.npy \
"${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_label.npy \
"${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_year.npy \
"${DATA_ROOT}"/raw
mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/author___affiliated_with___institution/edge_index.npy \
"${DATA_ROOT}"/raw/author_affiliated_with_institution_edges.npy
mv "${ROOT}"/mag240m_kddcup2021/processed/author___writes___paper/edge_index.npy \
"${DATA_ROOT}"/raw/author_writes_paper_edges.npy
mv "${ROOT}"/mag240m_kddcup2021/processed/paper___cites___paper/edge_index.npy \
"${DATA_ROOT}"/raw/paper_cites_paper_edges.npy
# Split and save the train/valid/test indices to the raw directory, with names
"train_idx.npy", "valid_idx.npy", "test_idx.npy":
python3 "${SCRIPT_DIR}"/split_and_save_indices.py --data_root=${DATA_ROOT}
+166
View File
@@ -0,0 +1,166 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Apply PCA to the papers' BERT features.
Compute papers' PCA features.
Recompute author and institution features from the paper PCA features.
"""
import pathlib
import time
from absl import app
from absl import flags
from absl import logging
import numpy as np
# pylint: disable=g-bad-import-order
import data_utils
Path = pathlib.Path
_NUMBER_OF_PAPERS_TO_ESTIMATE_PCA_ON = 1000000 # None indicates all.
FLAGS = flags.FLAGS
flags.DEFINE_string('data_root', None, 'Data root directory')
def _sample_vectors(vectors, num_samples, seed=0):
"""Randomly sample some vectors."""
rand = np.random.RandomState(seed=seed)
indices = rand.choice(vectors.shape[0], size=num_samples, replace=False)
return vectors[indices]
def _pca(feat):
"""Returns evals (variances), evecs (rows are principal components)."""
cov = np.cov(feat.T)
_, evals, evecs = np.linalg.svd(cov, full_matrices=True)
return evals, evecs
def _read_raw_paper_features():
"""Load raw paper features."""
path = Path(FLAGS.data_root) / data_utils.RAW_NODE_FEATURES_FILENAME
try: # Use mmap if possible.
features = np.load(path, mmap_mode='r')
except FileNotFoundError:
with open(path, 'rb') as fid:
features = np.load(fid)
return features
def _get_principal_components(features,
num_principal_components=129,
num_samples=10000,
seed=2,
dtype='f4'):
"""Estimate PCA features."""
sample = _sample_vectors(
features[:_NUMBER_OF_PAPERS_TO_ESTIMATE_PCA_ON], num_samples, seed=seed)
# Compute PCA basis.
_, evecs = _pca(sample)
return evecs[:num_principal_components].T.astype(dtype)
def _project_features_onto_principal_components(features,
principal_components,
block_size=1000000):
"""Apply PCA iteratively."""
num_principal_components = principal_components.shape[1]
dtype = principal_components.dtype
num_vectors = features.shape[0]
num_features = features.shape[0]
num_blocks = (num_features - 1) // block_size + 1
pca_features = np.empty([num_vectors, num_principal_components], dtype=dtype)
# Loop through in blocks.
start_time = time.time()
for i in range(num_blocks):
i_start = i * block_size
i_end = (i + 1) * block_size
f = np.array(features[i_start:i_end].copy())
pca_features[i_start:i_end] = np.dot(f, principal_components).astype(dtype)
del f
elapsed_time = time.time() - start_time
time_left = elapsed_time / (i + 1) * (num_blocks - i - 1)
logging.info('Features %d / %d. Elapsed time %.1f. Time left: %.1f', i_end,
num_vectors, elapsed_time, time_left)
return pca_features
def _read_adjacency_indices():
# Get adjacencies.
return data_utils.get_arrays(
data_root=FLAGS.data_root,
use_fused_node_labels=False,
use_fused_node_adjacencies=False,
return_pca_embeddings=False,
)
def _compute_author_pca_features(paper_pca_features, index_arrays):
return data_utils.paper_features_to_author_features(
index_arrays['author_paper_index'], paper_pca_features)
def _compute_institution_pca_features(author_pca_features, index_arrays):
return data_utils.author_features_to_institution_features(
index_arrays['institution_author_index'], author_pca_features)
def _write_array(path, array):
path.parent.mkdir(parents=True, exist_ok=True)
with open(path, 'wb') as fid:
np.save(fid, array)
def main(unused_argv):
data_root = Path(FLAGS.data_root)
raw_paper_features = _read_raw_paper_features()
principal_components = _get_principal_components(raw_paper_features)
paper_pca_features = _project_features_onto_principal_components(
raw_paper_features, principal_components)
del raw_paper_features
del principal_components
paper_pca_path = data_root / data_utils.PCA_PAPER_FEATURES_FILENAME
author_pca_path = data_root / data_utils.PCA_AUTHOR_FEATURES_FILENAME
institution_pca_path = (
data_root / data_utils.PCA_INSTITUTION_FEATURES_FILENAME)
merged_pca_path = data_root / data_utils.PCA_MERGED_FEATURES_FILENAME
_write_array(paper_pca_path, paper_pca_features)
# Compute author and institution features from paper PCA features.
index_arrays = _read_adjacency_indices()
author_pca_features = _compute_author_pca_features(paper_pca_features,
index_arrays)
_write_array(author_pca_path, author_pca_features)
institution_pca_features = _compute_institution_pca_features(
author_pca_features, index_arrays)
_write_array(institution_pca_path, institution_pca_features)
merged_pca_features = np.concatenate(
[paper_pca_features, author_pca_features, institution_pca_features],
axis=0)
del author_pca_features
del institution_pca_features
_write_array(merged_pca_path, merged_pca_features)
if __name__ == '__main__':
flags.mark_flag_as_required('data_root')
app.run(main)
+23
View File
@@ -0,0 +1,23 @@
wheel
absl-py>=0.12.0
chex>=0.0.7
dill>=0.3.3
dm-haiku>=0.0.4
# jaxline
git+https://github.com/deepmind/jaxline@927695a0b88e480d085590736e2daa36c2f2c68b
optax>=0.0.8
ml_collections
numpy>=1.16.4
tensorflow>=2.5.0
tensorflow-datasets>=4.3.0
dm-tree>=0.1.6
ogb>=1.3.1
# graph_nets
git+git://github.com/deepmind/graph_nets.git@64771dff0d74ca8e77b1f1dcd5a7d26634356d61
scipy>=1.2.1
jaxlib>=0.1.57+cuda110
tqdm>=4.28.1
annoy>=1.0.3
# jraph
git+git://github.com/deepmind/jraph.git@80d2f9e4d82f841a71d56ba44fa9e8781e93ae1f
google-cloud-storage
+57
View File
@@ -0,0 +1,57 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/bash
set -e
set -x
while getopts ":r:" opt; do
case ${opt} in
r )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: preprocess_data.sh -r <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
DATA_ROOT="${TASK_ROOT}"/data
PREPROCESSED_DIR="${DATA_ROOT}"/preprocessed
# Create preprocessed directory to move all files to it.
mkdir -p "${PREPROCESSED_DIR}"
# Run the CSR edge builder.
python "${SCRIPT_DIR}"/csr_builder.py --data_root="${DATA_ROOT}"
# Run the PCA feature builder.
python "${SCRIPT_DIR}"/pca_builder.py --data_root="${DATA_ROOT}"
# Run the neighbor-finder/fuser builder.
python "${SCRIPT_DIR}"/neighbor_builder.py --data_root="${DATA_ROOT}"
# Run the validation split generator.
python "${SCRIPT_DIR}"/generate_validation_splits.py \
--data_root="${DATA_ROOT}" \
--output_dir="${DATA_ROOT}/k_fold_splits"
+107
View File
@@ -0,0 +1,107 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x
while getopts ":r:" opt; do
case ${opt} in
r )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: run_pretrained_eval.sh -r <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
echo "
Note this script may take several days to run with default parameters on
a single machine. Reducing EPOCHS_TO_ENSEMBLE from 50 to < 5 should yield
slighly lower validation performance but possibly similar test performance.
"
echo "
Pre-requisites (See README):
* Python dependencies have been installed.
* pre-processed data is available in the task dir.
* pre-trained model weights are available in the task dir.
"
read -p "Press enter to continue"
# Can set this to "valid" or "test".
# On test the results will be ensembled for 10 models, and a submission file
# will be created.
# On validation, the results will be "gathered" for 10 models. This is because
# each model is trained on the train split + 90% of the validation split,
# and only evaluated on the remaining 10%, such that each validation paper
# is left out from training in exactly one of the 10 models.
SPLIT="test"
# We used 50 epochs in the submission to sample different subgraphs around
# each central node.
# For each of the 10 models in the ensemble, a single epoch takes about 1
# 10-25 minutes (depending on the GPU) on the test set, and about 2-3 minutes
# for the corresponding k-fold split of the validation set.
EPOCHS_TO_ENSEMBLE=50
DATA_ROOT=${TASK_ROOT}/data/
MODELS_ROOT=${TASK_ROOT}/models/
CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/
OUTPUT_DIR=${TASK_ROOT}/predictions/
# We run two seeds for each model of the k=10 k-fold
# first seed group: [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
# second seed group: [110, 111, 112, 113, 114, 115, 116, 117, 118, 119]
# Thes are the seeds that was selected based on cross validation for each fold.
BEST_SEEDS=(100 111 102 113 104 105 106 107 108 109)
for K_FOLD_INDEX in {0..9}; do
SEED=${BEST_SEEDS[${K_FOLD_INDEX}]}
RESTORE_PATH=${MODELS_ROOT}/k${K_FOLD_INDEX}_seed${SEED}
echo "Running k=${K_FOLD_INDEX} on ${SPLIT} split using ${RESTORE_PATH}"
# This saves the predictions for the K_FOLD_INDEXd'th model in the k-fold to
# "config.experiment_kwargs.config.predictions_dir" for subsequent ensembling
python "${SCRIPT_DIR}"/experiment.py \
--jaxline_mode="eval" \
--config="${SCRIPT_DIR}"/config.py \
--config.one_off_evaluate=True \
--config.checkpoint_dir=${CHECKPOINT_DIR}/${K_FOLD_INDEX} \
--config.restore_path=${RESTORE_PATH} \
--config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \
--config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \
--config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \
--config.experiment_kwargs.config.predictions_dir=${OUTPUT_DIR}/${K_FOLD_INDEX} \
--config.experiment_kwargs.config.eval.split=${SPLIT}
done
python "${SCRIPT_DIR}"/ensemble_predictions.py \
--split=${SPLIT} \
--data_root=${DATA_ROOT} \
--predictions_path=${OUTPUT_DIR} \
--output_path=${OUTPUT_DIR}
echo "Done"
+115
View File
@@ -0,0 +1,115 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x
while getopts ":r:" opt; do
case ${opt} in
r )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: run_training.sh -r <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
echo "
These scripts are provided for illustrative purposes. It is not practical for
actual training since it only uses a single machine, and likely requires
reducing the batch size and/or model size to fit on a single GPU.
For the actual submission we used training distributed in different ways:
* We used 4x Cloud TPU v4s to implement batch parallelism, so batch
size is effectively 8 times larger than the values in the config.
* We used separate CPUs to subsample neighborhoods around each central node, to
maximize TPU usage by always having a batch ready for the next iteration.
* We ran the online early-stopping evaluator on a separate machine with an
NVIDIA V100 GPU.
* We run identical replicas of all of the above for each of the 20 models
trained.
Using those mechanisms, training runs at ~2 steps per second, reaching 500k
steps in under 3 days.
"
echo "
Pre-requisites (See README):
* Python dependencies have been installed.
* pre-processed data is available in the task dir.
"
read -p "Press enter to continue"
# During early stopping seed/selection we ensembled 5 iterations.
EPOCHS_TO_ENSEMBLE=5
DATA_ROOT=${TASK_ROOT}/data/
CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/
# We run two seeds for each model of the k=10 k-fold.
BASE_SEED=100
for SEED_OFFSET in 0 10; do
for K_FOLD_INDEX in {0..9}; do
MODEL_SEED=`expr ${BASE_SEED} + ${SEED_OFFSET} + ${K_FOLD_INDEX}`
SUFFIX=k${K_FOLD_INDEX}_seed${MODEL_SEED}
echo "Running k=${K_FOLD_INDEX} with init seed ${MODEL_SEED}"
# This runs training (each model is trained on train split + 90% of the
# validation split) with early stopping, storing both "latest" model and
# "best" early-stopped model at `--config.checkpoint_dir`.
# Models are early stopped based on accuracy of on 10% of the validation data
# (each K_FOLD_INDEX leaves a different 10% of data out from training) left
# out from training.
# Models are stored at the end of training. Intermediate models can also be
# stored while training by sending a SIGINT signal (Ctrl+C) which will not
# interrupt the training.
# It is possible to interrupt training using (Ctrl+\) and then continue
# passing the corresponding `--config.restore_path=${RESTORE_PATH}` which is
# stored when interrupting.
python "${SCRIPT_DIR}"/experiment.py \
--jaxline_mode="train_eval_multithreaded" \
--config="${SCRIPT_DIR}"/config.py \
--config.random_seed=${MODEL_SEED} \
--config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \
--config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \
--config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \
--config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \
--config.experiment_kwargs.config.eval.split="valid"
done
done
# Each of the 20 (two for each value of k) jobs generate paths of the form:
# RESTORE_PATH=${--config.checkpoint_dir}/models/best/step_${STEP}_${TIMESTAMP}
# From each pair, the best one is selected (based on the validation accuracy
# as reported on the logs or in tensorboard events also stored at
# `${--config.checkpoint_dir}/eval`). These can then be used as the
# "RESTORE_PATHS" for `./run_pretrained_eval.sh`.
echo "Done"
+74
View File
@@ -0,0 +1,74 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Scheduling utilities."""
import jax.numpy as jnp
def apply_ema_decay(
ema_value: jnp.ndarray,
current_value: jnp.ndarray,
decay: jnp.ndarray,
) -> jnp.ndarray:
"""Implements EMA."""
return ema_value * decay + current_value * (1 - decay)
def ema_decay_schedule(
base_rate: jnp.ndarray,
step: jnp.ndarray,
total_steps: jnp.ndarray,
use_schedule: bool,
) -> jnp.ndarray:
"""Anneals decay rate to 1 with cosine schedule."""
if not use_schedule:
return base_rate
multiplier = _cosine_decay(step, total_steps, 1.)
return 1. - (1. - base_rate) * multiplier
def _cosine_decay(
global_step: jnp.ndarray,
max_steps: int,
initial_value: float,
) -> jnp.ndarray:
"""Simple implementation of cosine decay from TF1."""
global_step = jnp.minimum(global_step, max_steps).astype(jnp.float32)
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
decayed_learning_rate = initial_value * cosine_decay_value
return decayed_learning_rate
def learning_schedule(
global_step: jnp.ndarray,
base_learning_rate: float,
total_steps: int,
warmup_steps: int,
use_schedule: bool,
) -> float:
"""Cosine learning rate scheduler."""
# Compute LR & Scaled LR
if not use_schedule:
return base_learning_rate
warmup_learning_rate = (
global_step.astype(jnp.float32) / int(warmup_steps) *
base_learning_rate if warmup_steps > 0 else base_learning_rate)
# Cosine schedule after warmup.
decay_learning_rate = _cosine_decay(global_step - warmup_steps,
total_steps - warmup_steps,
base_learning_rate)
return jnp.where(global_step < warmup_steps, warmup_learning_rate,
decay_learning_rate)
+50
View File
@@ -0,0 +1,50 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Split and save the train/valid/test indices.
Usage:
python3 split_and_save_indices.py --data_root="mag_data"
"""
import pathlib
from absl import app
from absl import flags
import numpy as np
import torch
Path = pathlib.Path
FLAGS = flags.FLAGS
flags.DEFINE_string('data_root', None, 'Data root directory')
def main(argv) -> None:
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
mag_directory = Path(FLAGS.data_root) / 'mag240m_kddcup2021'
raw_directory = mag_directory / 'raw'
raw_directory.parent.mkdir(parents=True, exist_ok=True)
splits_dict = torch.load(str(mag_directory / 'split_dict.pt'))
for key, indices in splits_dict.items():
np.save(str(raw_directory / f'{key}_idx.npy'), indices)
if __name__ == '__main__':
flags.mark_flag_as_required('root')
app.run(main)
+250
View File
@@ -0,0 +1,250 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for subsampling the MAG dataset."""
import collections
import jraph
import numpy as np
def get_or_sample_row(node_id: int,
nb_neighbours: int,
csr_matrix, remove_duplicates: bool):
"""Either obtain entire row or a subsampled set of neighbours."""
if node_id + 1 >= csr_matrix.indptr.shape[0]:
lo = 0
hi = 0
else:
lo = csr_matrix.indptr[node_id]
hi = csr_matrix.indptr[node_id + 1]
if lo == hi: # Skip empty neighbourhoods
neighbours = None
elif hi - lo <= nb_neighbours:
neighbours = csr_matrix.indices[lo:hi]
elif hi - lo < 5 * nb_neighbours: # For small surroundings, sample directly
nb_neighbours = min(nb_neighbours, hi - lo)
inds = lo + np.random.choice(hi - lo, size=(nb_neighbours,), replace=False)
neighbours = csr_matrix.indices[inds]
else: # Otherwise, do not slice -- sample indices instead
# To extend GraphSAGE ("uniform w/ replacement"), modify this call
inds = np.random.randint(lo, hi, size=(nb_neighbours,))
if remove_duplicates:
inds = np.unique(inds)
neighbours = csr_matrix.indices[inds]
return neighbours
def get_neighbours(node_id: int,
node_type: int,
neighbour_type: int,
nb_neighbours: int,
remove_duplicates: bool,
author_institution_csr, institution_author_csr,
author_paper_csr, paper_author_csr,
paper_paper_csr, paper_paper_transpose_csr):
"""Fetch the edge indices from one node to corresponding neighbour type."""
if node_type == 0 and neighbour_type == 0:
csr = paper_paper_transpose_csr # Citing
elif node_type == 0 and neighbour_type == 1:
csr = paper_author_csr
elif node_type == 0 and neighbour_type == 3:
csr = paper_paper_csr # Cited
elif node_type == 1 and neighbour_type == 0:
csr = author_paper_csr
elif node_type == 1 and neighbour_type == 2:
csr = author_institution_csr
elif node_type == 2 and neighbour_type == 1:
csr = institution_author_csr
else:
raise ValueError('Non-existent edge type requested')
return get_or_sample_row(node_id, nb_neighbours, csr, remove_duplicates)
def get_senders(neighbour_type: int,
sender_index,
paper_features):
"""Get the sender features from given neighbours."""
if neighbour_type == 0 or neighbour_type == 3:
sender_features = paper_features[sender_index]
elif neighbour_type == 1 or neighbour_type == 2:
sender_features = np.zeros((sender_index.shape[0],
paper_features.shape[1])) # Consider averages
else:
raise ValueError('Non-existent node type requested')
return sender_features
def make_edge_type_feature(node_type: int, neighbour_type: int):
edge_feats = np.zeros(7)
edge_feats[node_type] = 1.0
edge_feats[neighbour_type + 3] = 1.0
return edge_feats
def subsample_graph(paper_id: int,
author_institution_csr,
institution_author_csr,
author_paper_csr,
paper_author_csr,
paper_paper_csr,
paper_paper_transpose_csr,
max_nb_neighbours_per_type,
max_nodes=None,
max_edges=None,
paper_years=None,
remove_future_nodes=False,
deduplicate_nodes=False) -> jraph.GraphsTuple:
"""Subsample a graph around given paper ID."""
if paper_years is not None:
root_paper_year = paper_years[paper_id]
else:
root_paper_year = None
# Add the center node as "node-zero"
sub_nodes = [paper_id]
num_nodes_in_subgraph = 1
num_edges_in_subgraph = 0
reached_node_budget = False
reached_edge_budget = False
node_and_type_to_index_in_subgraph = dict()
node_and_type_to_index_in_subgraph[(paper_id, 0)] = 0
# Store all (integer) depths as an additional feature
depths = [0]
types = [0]
sub_edges = []
sub_senders = []
sub_receivers = []
# Store all unprocessed neighbours
# Each neighbour is stored as a 4-tuple (node_index in original graph,
# node_index in subsampled graph, type, number of hops away from source).
# TYPES: 0: paper, 1: author, 2: institution, 3: paper (for bidirectional)
neighbour_deque = collections.deque([(paper_id, 0, 0, 0)])
max_depth = len(max_nb_neighbours_per_type)
while neighbour_deque and not reached_edge_budget:
left_entry = neighbour_deque.popleft()
node_index, node_index_in_sampled_graph, node_type, node_depth = left_entry
# Expand from this node, to a node of related type
for neighbour_type in range(4):
if reached_edge_budget:
break # Budget may have been reached in previous type; break here.
nb_neighbours = max_nb_neighbours_per_type[node_depth][node_type][neighbour_type] # pylint:disable=line-too-long
# Only extend if we want to sample further in this edge type
if nb_neighbours > 0:
sampled_neighbors = get_neighbours(
node_index,
node_type,
neighbour_type,
nb_neighbours,
deduplicate_nodes,
author_institution_csr,
institution_author_csr,
author_paper_csr,
paper_author_csr,
paper_paper_csr,
paper_paper_transpose_csr,
)
if sampled_neighbors is not None:
if remove_future_nodes and root_paper_year is not None:
if neighbour_type in [0, 3]:
sampled_neighbors = [
x for x in sampled_neighbors
if paper_years[x] <= root_paper_year
]
if not sampled_neighbors:
continue
nb_neighbours = len(sampled_neighbors)
edge_feature = make_edge_type_feature(node_type, neighbour_type)
for neighbor_original_idx in sampled_neighbors:
# Key into dict of existing nodes using both node id and type.
neighbor_key = (neighbor_original_idx, neighbour_type % 3)
# Get existing idx in subgraph if it exists.
neighbor_subgraph_idx = node_and_type_to_index_in_subgraph.get(
neighbor_key, None)
if (not reached_node_budget and
(not deduplicate_nodes or neighbor_subgraph_idx is None)):
# If it does not exist already, or we are not deduplicating,
# just create a new node and update the dict.
neighbor_subgraph_idx = num_nodes_in_subgraph
node_and_type_to_index_in_subgraph[neighbor_key] = (
neighbor_subgraph_idx)
num_nodes_in_subgraph += 1
sub_nodes.append(neighbor_original_idx)
types.append(neighbour_type % 3)
depths.append(node_depth + 1)
if max_nodes is not None and num_nodes_in_subgraph >= max_nodes:
reached_node_budget = True
continue # Move to next neighbor which might already exist.
if node_depth < max_depth - 1:
# If the neighbours are to be further expanded, enqueue them.
# Expand only if the nodes did not already exist.
neighbour_deque.append(
(neighbor_original_idx, neighbor_subgraph_idx,
neighbour_type % 3, node_depth + 1))
# The neighbor id within graph is now fixed; just add edges.
if neighbor_subgraph_idx is not None:
# Either node existed before or was successfully added.
sub_senders.append(neighbor_subgraph_idx)
sub_receivers.append(node_index_in_sampled_graph)
sub_edges.append(edge_feature)
num_edges_in_subgraph += 1
if max_edges is not None and num_edges_in_subgraph >= max_edges:
reached_edge_budget = True
break # Break out of adding edges for this neighbor type
# Stitch the graph together
sub_nodes = np.array(sub_nodes, dtype=np.int32)
if sub_senders:
sub_senders = np.array(sub_senders, dtype=np.int32)
sub_receivers = np.array(sub_receivers, dtype=np.int32)
sub_edges = np.stack(sub_edges, axis=0)
else:
# Use empty arrays.
sub_senders = np.zeros([0], dtype=np.int32)
sub_receivers = np.zeros([0], dtype=np.int32)
sub_edges = np.zeros([0, 7])
# Finally, derive the sizes
sub_n_node = np.array([sub_nodes.shape[0]])
sub_n_edge = np.array([sub_senders.shape[0]])
assert sub_nodes.shape[0] == num_nodes_in_subgraph
assert sub_edges.shape[0] == num_edges_in_subgraph
if max_nodes is not None:
assert num_nodes_in_subgraph <= max_nodes
if max_edges is not None:
assert num_edges_in_subgraph <= max_edges
types = np.array(types)
depths = np.array(depths)
sub_nodes = {
'index': sub_nodes.astype(np.int32),
'type': types.astype(np.int16),
'depth': depths.astype(np.int16),
}
return jraph.GraphsTuple(nodes=sub_nodes,
edges=sub_edges.astype(np.float16),
senders=sub_senders.astype(np.int32),
receivers=sub_receivers.astype(np.int32),
globals=np.array([0], dtype=np.int16),
n_node=sub_n_node.astype(dtype=np.int32),
n_edge=sub_n_edge.astype(dtype=np.int32))
+110
View File
@@ -0,0 +1,110 @@
# DeepMind entry for PCQM4M-LSC
This repository contains DeepMind's entry to the [PCWM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry)
track of the [OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/)
(OGB-LSC).
For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
## DeepMind PCQ Team ("Quantum")
(in alphabetical order)
- Ravichandra Addanki
- Peter Battaglia
- David Budden
- Andreea Deac
- Jonathan Godwin
- Alvaro Sanchez-Gonzalez
- Wai Lok Sibon Li
- Jacklynn Stott
- Shantanu Thakoor
- Petar Veličković
## Performance
Our final test set performance was achieved by pooling an ensemble of 20 models
(10 folds x 2 seeds). See [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf) for details.
Each model was trained for < 48 hours using 4x Google Cloud TPUv4 and 1x AMD
EPYC 7B12 64-core CPU @2.25GHz.
Inference takes < 6 hours on 1x NVIDIA V100 16GB GPU and 1x Intel Xeon Gold 6148
20-core CPU @2.40GHz.
# Running our model
## Setup
You can set up Python virtual environment (you might need to install the
`python3-venv` package first) with all needed dependencies inside the forked
`deepmind_research` repository using:
```bash
python3 -m venv /tmp/pcq_venv
source /tmp/pcq_venv/bin/activate
pip3 install --upgrade pip setuptools wheel
pip3 install -r ogb_lsc/pcq/requirements.txt
```
## Downloading data and model weights
All necessary data and pre-trained model weights can be downloaded by running
the following command.
This downloads about ~ 150 GB worth of model checkpoints.
```bash
python download_required_pcq_data.py --data_root=${HOME}/data/
```
## Generating Pre-processed features
All the additional features used in training
(k-fold splits and conformer position features) can be generated by running.
```bash
/bin/bash run_preprocessing.sh -r ${HOME}/data/pcq/
```
## Reproducing our final results
We have provided pre-trained weights of our final submission for convenience. To
reproduce our final results, please run `run_pretrained_eval.sh` as follows.
```bash
/bin/bash run_pretrained_eval.sh -r ${HOME}/data/pcq/
```
## Retraining our model
Disclaimer: This script is provided for illustrative purposes. It is not
practical for actual training since it only uses a single machine, and likely
requires reducing the batch size and/or model size to fit on a single GPU.
If you still want to train a model, please run `run_training.sh`. To simply
validate that the code is running correctly on your hardware setup, consider
setting `debug=True` in `config.py`, which trains a smaller model.
```bash
/bin/bash run_training.sh -r ${HOME}/data/pcq/
```
# Citation
To cite this work (together with our MAG240M-LSC entry):
```latex
@article{deepmind2021ogb,
author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea
Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro
Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar
Veli\v{c}kovi\'{c}},
title = {Large-scale graph representation learning with very deep GNNs and
self-supervision},
year = {2021},
}
```
Our technical report can be found [here](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
+135
View File
@@ -0,0 +1,135 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dynamic batching utilities."""
from typing import Generator, Iterator, Sequence, Tuple
import jax.tree_util as tree
import jraph
import numpy as np
_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph")
def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple],
n_node: int, n_edge: int,
n_graph: int) -> Generator[jraph.GraphsTuple, None, None]:
"""Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`.
Elements of the `graphs_tuple_iterator` will be incrementally added to a batch
until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This
means each element yielded by this generator
For situations where you have variable sized data, it"s useful to be able to
have variable sized batches. This is especially the case if you have a loss
defined on the variable shaped element (for example, nodes in a graph).
Args:
graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`.
n_node: The maximum number of nodes in a batch.
n_edge: The maximum number of edges in a batch.
n_graph: The maximum number of graphs in a batch.
Yields:
A `jraph.GraphsTuple` batch of graphs.
Raises:
ValueError: if the number of graphs is < 2.
RuntimeError: if the `graphs_tuple_iterator` contains elements which are not
`jraph.GraphsTuple`s.
RuntimeError: if a graph is found which is larger than the batch size.
"""
if n_graph < 2:
raise ValueError("The number of graphs in a batch size must be greater or "
f"equal to `2` for padding with graphs, got {n_graph}.")
valid_batch_size = (n_node - 1, n_edge, n_graph - 1)
accumulated_graphs = []
num_accumulated_nodes = 0
num_accumulated_edges = 0
num_accumulated_graphs = 0
for element in graphs_tuple_iterator:
element_nodes, element_edges, element_graphs = _get_graph_size(element)
if _is_over_batch_size(element, valid_batch_size):
graph_size = element_nodes, element_edges, element_graphs
graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)}
batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)}
raise RuntimeError("Found graph bigger than batch size. Valid Batch "
f"Size: {batch_size}, Graph Size: {graph_size}")
if not accumulated_graphs:
# If this is the first element of the batch, set it and continue.
accumulated_graphs = [element]
num_accumulated_nodes = element_nodes
num_accumulated_edges = element_edges
num_accumulated_graphs = element_graphs
continue
else:
# Otherwise check if there is space for the graph in the batch:
if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
(num_accumulated_nodes + element_nodes > n_node - 1) or
(num_accumulated_edges + element_edges > n_edge)):
# If there is, add it to the batch
batched_graph = _batch_np(accumulated_graphs)
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
accumulated_graphs = [element]
num_accumulated_nodes = element_nodes
num_accumulated_edges = element_edges
num_accumulated_graphs = element_graphs
else:
# Otherwise, return the old batch and start a new batch.
accumulated_graphs.append(element)
num_accumulated_nodes += element_nodes
num_accumulated_edges += element_edges
num_accumulated_graphs += element_graphs
# We may still have data in batched graph.
if accumulated_graphs:
batched_graph = _batch_np(accumulated_graphs)
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
# Calculates offsets for sender and receiver arrays, caused by concatenating
# the nodes arrays.
offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]]))
def _map_concat(nests):
concat = lambda *args: np.concatenate(args)
return tree.tree_multimap(concat, *nests)
return jraph.GraphsTuple(
n_node=np.concatenate([g.n_node for g in graphs]),
n_edge=np.concatenate([g.n_edge for g in graphs]),
nodes=_map_concat([g.nodes for g in graphs]),
edges=_map_concat([g.edges for g in graphs]),
globals=_map_concat([g.globals for g in graphs]),
senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
receivers=np.concatenate(
[g.receivers + o for g, o in zip(graphs, offsets)]))
def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]:
n_node = np.sum(graph.n_node)
n_edge = len(graph.senders)
n_graph = len(graph.n_node)
return n_node, n_edge, n_graph
def _is_over_batch_size(
graph: jraph.GraphsTuple,
graph_batch_size: Tuple[int, int, int],
) -> bool:
graph_size = _get_graph_size(graph)
return any([x > y for x, y in zip(graph_size, graph_batch_size)])
+130
View File
@@ -0,0 +1,130 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experiment config for PCQM4M-LSC entry."""
from jaxline import base_config
from ml_collections import config_dict
def get_config(debug: bool = False) -> config_dict.ConfigDict:
"""Get Jaxline experiment config."""
config = base_config.get_base_config()
# E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
config.restore_path = config_dict.placeholder(str)
training_batch_size = 64
eval_batch_size = 64
## Experiment config.
loss_config_name = 'RegressionLossConfig'
loss_kwargs = dict(
exponent=1., # 2 for l2 loss, 1 for l1 loss, etc...
)
dataset_config = dict(
data_root=config_dict.placeholder(str),
augment_with_random_mirror_symmetry=True,
k_fold_split_id=config_dict.placeholder(int),
num_k_fold_splits=config_dict.placeholder(int),
# Options: "in" or "out".
# Filter=in would keep the samples with nans in the conformer features.
# Filter=out would keep the samples with no NaNs anywhere in the conformer
# features.
filter_in_or_out_samples_with_nans_in_conformers=(
config_dict.placeholder(str)),
cached_conformers_file=config_dict.placeholder(str))
model_config = dict(
mlp_hidden_size=512,
mlp_layers=2,
latent_size=512,
use_layer_norm=False,
num_message_passing_steps=32,
shared_message_passing_weights=False,
mask_padding_graph_at_every_step=True,
loss_config_name=loss_config_name,
loss_kwargs=loss_kwargs,
processor_mode='resnet',
global_reducer='sum',
node_reducer='sum',
dropedge_rate=0.1,
dropnode_rate=0.1,
aux_multiplier=0.1,
add_relative_distance=True,
add_relative_displacement=True,
add_absolute_positions=False,
position_normalization=2.,
relative_displacement_normalization=1.,
ignore_globals=False,
ignore_globals_from_final_layer_for_predictions=True,
)
if debug:
# Make network smaller.
model_config.update(dict(
mlp_hidden_size=32,
mlp_layers=1,
latent_size=32,
num_message_passing_steps=1))
config.experiment_kwargs = config_dict.ConfigDict(
dict(
config=dict(
debug=debug,
predictions_dir=config_dict.placeholder(str),
ema=True,
ema_decay=0.9999,
sample_random=0.05,
optimizer=dict(
name='adam',
optimizer_kwargs=dict(b1=.9, b2=.95),
lr_schedule=dict(
warmup_steps=int(5e4),
decay_steps=int(5e5),
init_value=1e-5,
peak_value=1e-4,
end_value=0.,
),
),
model=model_config,
dataset_config=dataset_config,
# As a rule of thumb, use the following statistics:
# Avg. # nodes in graph: 16.
# Avg. # edges in graph: 40.
training=dict(
dynamic_batch_size={
'n_node': 256 if debug else 16 * training_batch_size,
'n_edge': 512 if debug else 40 * training_batch_size,
'n_graph': 2 if debug else training_batch_size,
},),
evaluation=dict(
split='valid',
dynamic_batch_size=dict(
n_node=256 if debug else 16 * eval_batch_size,
n_edge=512 if debug else 40 * eval_batch_size,
n_graph=2 if debug else eval_batch_size,
)))))
## Training loop config.
config.training_steps = int(5e6)
config.checkpoint_dir = '/tmp/checkpoint/pcq/'
config.train_checkpoint_all_hosts = False
config.save_checkpoint_interval = 300
config.log_train_data_interval = 60
config.log_tensors_interval = 60
config.best_model_eval_metric = 'mae'
config.best_model_eval_metric_higher_is_better = False
return config
+324
View File
@@ -0,0 +1,324 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Conformer utilities."""
import copy
from typing import List, Optional
from absl import logging
import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import tensorflow.compat.v2 as tf
def generate_conformers(
molecule: Chem.rdchem.Mol,
max_num_conformers: int,
*,
random_seed: int = -1,
prune_rms_thresh: float = -1.0,
max_iter: int = -1,
fallback_to_random: bool = False,
) -> Chem.rdchem.Mol:
"""Generates conformers for a given molecule.
Args:
molecule: molecular representation of the compound.
max_num_conformers: maximum number of conformers to generate. If pruning is
done, the returned number of conformers is not guaranteed to match
max_num_conformers.
random_seed: random seed to use for conformer generation.
prune_rms_thresh: RMSD threshold which allows to prune conformers that are
too similar.
max_iter: Maximum number of iterations to perform when optimising MMFF force
field. If set to <= 0, energy optimisation is not performed.
fallback_to_random: if conformers cannot be obtained, use random coordinates
to initialise.
Returns:
Copy of a `molecule` with added hydrogens. The returned molecule contains
force field-optimised conformers. The number of conformers is guaranteed to
be <= max_num_conformers.
"""
mol = copy.deepcopy(molecule)
mol = Chem.AddHs(mol)
mol = _embed_conformers(
mol,
max_num_conformers,
random_seed,
prune_rms_thresh,
fallback_to_random,
use_random=False)
if max_iter > 0:
mol_with_conformers = _minimize_by_mmff(mol, max_iter)
if mol_with_conformers is None:
mol_with_conformers = _minimize_by_uff(mol, max_iter)
else:
mol_with_conformers = mol
# Aligns conformations in a molecule to each other using the first
# conformation as the reference.
AllChem.AlignMolConformers(mol_with_conformers)
# We remove hydrogens to keep the number of atoms consistent with the graph
# nodes.
mol_with_conformers = Chem.RemoveHs(mol_with_conformers)
return mol_with_conformers
def atom_to_feature_vector(
atom: rdkit.Chem.rdchem.Atom,
conformer: Optional[np.ndarray] = None,
) -> List[float]:
"""Converts rdkit atom object to feature list of indices.
Args:
atom: rdkit atom object.
conformer: Generated conformers. Returns -1 values if set to None.
Returns:
List containing positions (x, y, z) of each atom from the conformer.
"""
if conformer:
pos = conformer.GetAtomPosition(atom.GetIdx())
return [pos.x, pos.y, pos.z]
return [np.nan, np.nan, np.nan]
def compute_conformer(smile: str, max_iter: int = -1) -> np.ndarray:
"""Computes conformer.
Args:
smile: Smile string.
max_iter: Maximum number of iterations to perform when optimising MMFF force
field. If set to <= 0, energy optimisation is not performed.
Returns:
A tuple containing index, fingerprint and conformer.
Raises:
RuntimeError: If unable to convert smile string to RDKit mol.
"""
mol = rdkit.Chem.MolFromSmiles(smile)
if not mol:
raise RuntimeError('Unable to convert smile to molecule: %s' % smile)
conformer_failed = False
try:
mol = generate_conformers(
mol,
max_num_conformers=1,
random_seed=45,
prune_rms_thresh=0.01,
max_iter=max_iter)
except IOError as e:
logging.exception('Failed to generate conformers for %s . IOError %s.',
smile, e)
conformer_failed = True
except ValueError:
logging.error('Failed to generate conformers for %s . ValueError', smile)
conformer_failed = True
except: # pylint: disable=bare-except
logging.error('Failed to generate conformers for %s.', smile)
conformer_failed = True
atom_features_list = []
conformer = None if conformer_failed else list(mol.GetConformers())[0]
for atom in mol.GetAtoms():
atom_features_list.append(atom_to_feature_vector(atom, conformer))
conformer_features = np.array(atom_features_list, dtype=np.float32)
return conformer_features
def get_random_rotation_matrix(include_mirror_symmetry: bool) -> tf.Tensor:
"""Returns a single random rotation matrix."""
rotation_matrix = _get_random_rotation_3d()
if include_mirror_symmetry:
random_mirror_symmetry = _get_random_mirror_symmetry()
rotation_matrix = tf.matmul(rotation_matrix, random_mirror_symmetry)
return rotation_matrix
def rotate(vectors: tf.Tensor, rotation_matrix: tf.Tensor) -> tf.Tensor:
"""Batch of vectors on a single rotation matrix."""
return tf.matmul(vectors, rotation_matrix)
def _embed_conformers(
molecule: Chem.rdchem.Mol,
max_num_conformers: int,
random_seed: int,
prune_rms_thresh: float,
fallback_to_random: bool,
*,
use_random: bool = False,
) -> Chem.rdchem.Mol:
"""Embeds conformers into a copy of a molecule.
If random coordinates allowed, tries not to use random coordinates at first,
and uses random only if fails.
Args:
molecule: molecular representation of the compound.
max_num_conformers: maximum number of conformers to generate. If pruning is
done, the returned number of conformers is not guaranteed to match
max_num_conformers.
random_seed: random seed to use for conformer generation.
prune_rms_thresh: RMSD threshold which allows to prune conformers that are
too similar.
fallback_to_random: if conformers cannot be obtained, use random coordinates
to initialise.
*:
use_random: Use random coordinates. Shouldn't be set by any caller except
this function itself.
Returns:
A copy of a molecule with embedded conformers.
Raises:
ValueError: if conformers cannot be obtained for a given molecule.
"""
mol = copy.deepcopy(molecule)
# Obtains parameters for conformer generation.
# In particular, ETKDG is experimental-torsion basic knowledge distance
# geometry, which allows to randomly generate an initial conformation that
# satisfies various geometric constraints such as lower and upper bounds on
# the distances between atoms.
params = AllChem.ETKDGv3()
params.randomSeed = random_seed
params.pruneRmsThresh = prune_rms_thresh
params.numThreads = -1
params.useRandomCoords = use_random
conf_ids = AllChem.EmbedMultipleConfs(mol, max_num_conformers, params)
if not conf_ids:
if not fallback_to_random or use_random:
raise ValueError('Cant get conformers')
return _embed_conformers(
mol,
max_num_conformers,
random_seed,
prune_rms_thresh,
fallback_to_random,
use_random=True)
return mol
def _minimize_by_mmff(
molecule: Chem.rdchem.Mol,
max_iter: int,
) -> Optional[Chem.rdchem.Mol]:
"""Minimizes forcefield for conformers using MMFF algorithm.
Args:
molecule: a datastructure containing conformers.
max_iter: number of maximum iterations to use when optimising force field.
Returns:
A copy of a `molecule` containing optimised conformers; or None if MMFF
cannot be performed.
"""
molecule_props = AllChem.MMFFGetMoleculeProperties(molecule)
if molecule_props is None:
return None
mol = copy.deepcopy(molecule)
for conf_id in range(mol.GetNumConformers()):
ff = AllChem.MMFFGetMoleculeForceField(
mol, molecule_props, confId=conf_id, ignoreInterfragInteractions=False)
ff.Initialize()
# minimises a conformer within a mol in place.
ff.Minimize(max_iter)
return mol
def _minimize_by_uff(
molecule: Chem.rdchem.Mol,
max_iter: int,
) -> Chem.rdchem.Mol:
"""Minimizes forcefield for conformers using UFF algorithm.
Args:
molecule: a datastructure containing conformers.
max_iter: number of maximum iterations to use when optimising force field.
Returns:
A copy of a `molecule` containing optimised conformers.
"""
mol = copy.deepcopy(molecule)
conf_ids = range(mol.GetNumConformers())
for conf_id in conf_ids:
ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id)
ff.Initialize()
# minimises a conformer within a mol in place.
ff.Minimize(max_iter)
return mol
def _get_symmetry_rotation_matrix(sign: tf.Tensor) -> tf.Tensor:
"""Returns the 2d/3d matrix for mirror symmetry."""
zero = tf.zeros_like(sign)
one = tf.ones_like(sign)
# pylint: disable=bad-whitespace,bad-continuation
rot = [sign, zero, zero,
zero, one, zero,
zero, zero, one]
# pylint: enable=bad-whitespace,bad-continuation
shape = (3, 3)
rot = tf.stack(rot, axis=-1)
rot = tf.reshape(rot, shape)
return rot
def _quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
"""Converts a batch of quaternions to a batch of rotation matrices."""
q0 = quaternion[0]
q1 = quaternion[1]
q2 = quaternion[2]
q3 = quaternion[3]
r00 = 2 * (q0 * q0 + q1 * q1) - 1
r01 = 2 * (q1 * q2 - q0 * q3)
r02 = 2 * (q1 * q3 + q0 * q2)
r10 = 2 * (q1 * q2 + q0 * q3)
r11 = 2 * (q0 * q0 + q2 * q2) - 1
r12 = 2 * (q2 * q3 - q0 * q1)
r20 = 2 * (q1 * q3 - q0 * q2)
r21 = 2 * (q2 * q3 + q0 * q1)
r22 = 2 * (q0 * q0 + q3 * q3) - 1
matrix = tf.stack([r00, r01, r02,
r10, r11, r12,
r20, r21, r22], axis=-1)
return tf.reshape(matrix, [3, 3])
def _get_random_rotation_3d() -> tf.Tensor:
random_quaternions = tf.random.normal(
shape=[4], dtype=tf.float32)
random_quaternions /= tf.linalg.norm(
random_quaternions, axis=-1, keepdims=True)
return _quaternion_to_rotation_matrix(random_quaternions)
def _get_random_mirror_symmetry() -> tf.Tensor:
random_0_1 = tf.random.uniform(
shape=(), minval=0, maxval=2, dtype=tf.int32)
random_signs = tf.cast((2 * random_0_1) - 1, tf.float32)
return _get_symmetry_rotation_matrix(random_signs)
+391
View File
@@ -0,0 +1,391 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset utilities."""
from typing import List, Optional
import jax
import jraph
from ml_collections import config_dict
import numpy as np
from ogb import utils
from ogb.utils import features
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import tree
# pylint: disable=g-bad-import-order
# pytype: disable=import-error
import batching_utils
import conformer_utils
import datasets
def build_dataset_iterator(
data_root: str,
split: str,
dynamic_batch_size_config: config_dict.ConfigDict,
sample_random: float,
cached_conformers_file: str,
debug: bool = False,
is_training: bool = True,
augment_with_random_mirror_symmetry: bool = False,
positions_noise_std: Optional[float] = None,
k_fold_split_id: Optional[int] = None,
num_k_fold_splits: Optional[int] = None,
filter_in_or_out_samples_with_nans_in_conformers: Optional[str] = None,
):
"""Returns an iterator over Batches from the dataset."""
if debug:
max_items_to_read_from_dataset = 10
prefetch_buffer_size = 1
shuffle_buffer_size = 1
else:
max_items_to_read_from_dataset = -1 # < 0 means no limit.
prefetch_buffer_size = 64
# It can take a while to fill the shuffle buffer with k fold splits.
shuffle_buffer_size = 128 if k_fold_split_id is None else int(1e6)
num_local_devices = jax.local_device_count()
# Load all smile strings.
indices, smiles, labels = _load_smiles(
data_root,
split,
k_fold_split_id=k_fold_split_id,
num_k_fold_splits=num_k_fold_splits)
if debug:
indices = indices[:100]
smiles = smiles[:100]
labels = labels[:100]
# Generate all conformer features from smile strings ahead of time.
# This gives us a boost from multi-parallelism as opposed to doing it
# online.
conformers = _load_conformers(indices, smiles, cached_conformers_file)
data_generator = (
lambda: _get_pcq_graph_generator(indices, smiles, labels, conformers))
# Create a dataset yielding graphs from smile strings.
example = next(data_generator())
signature_from_example = tree.map_structure(_numpy_to_tensor_spec, example)
ds = tf.data.Dataset.from_generator(
data_generator, output_signature=signature_from_example)
ds = ds.take(max_items_to_read_from_dataset)
ds = ds.cache()
if is_training:
ds = ds.shuffle(shuffle_buffer_size)
# Apply transformations.
def map_fn(graph, conformer_positions):
graph = _maybe_one_hot_atoms_with_noise(
graph, is_training=is_training, sample_random=sample_random)
# Add conformer features.
graph = _add_conformer_features(
graph,
conformer_positions,
augment_with_random_mirror_symmetry=augment_with_random_mirror_symmetry,
noise_std=positions_noise_std,
is_training=is_training,
)
return _downcast_ints(graph)
ds = ds.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
if filter_in_or_out_samples_with_nans_in_conformers:
if filter_in_or_out_samples_with_nans_in_conformers not in ("in", "out"):
raise ValueError(
"Unknown value specified for the argument "
"`filter_in_or_out_samples_with_nans_in_conformers`: %s" %
filter_in_or_out_samples_with_nans_in_conformers)
filter_fn = _get_conformer_filter(
with_nans=(filter_in_or_out_samples_with_nans_in_conformers == "in"))
ds = ds.filter(filter_fn)
if is_training:
ds = ds.shard(jax.process_count(), jax.process_index())
ds = ds.repeat()
ds = ds.prefetch(prefetch_buffer_size)
it = tfds.as_numpy(ds)
# Dynamic batching.
batched_gen = batching_utils.dynamically_batch(
it,
n_node=dynamic_batch_size_config.n_node + 1,
n_edge=dynamic_batch_size_config.n_edge,
n_graph=dynamic_batch_size_config.n_graph + 1,
)
if is_training:
# Stack `num_local_devices` of batches together for pmap updates.
batch_size = num_local_devices
def _batch(l):
assert l
return tree.map_structure(lambda *l: np.stack(l, axis=0), *l)
def batcher_fn():
batch = []
for sample in batched_gen:
batch.append(sample)
if len(batch) == batch_size:
yield _batch(batch)
batch = []
if batch:
yield _batch(batch)
for sample in batcher_fn():
yield sample
else:
for sample in batched_gen:
yield sample
def _get_conformer_filter(with_nans: bool):
"""Selects a conformer filter to apply.
Args:
with_nans: Filter only selects samples with NaNs in conformer features.
Else, selects samples without any NaNs in conformer features.
Returns:
A function that can be used with tf.data.Dataset.filter().
Raises:
ValueError:
If the input graph to the filter has no conformer features to filter.
"""
def _filter(graph: jraph.GraphsTuple) -> tf.Tensor:
if ("positions" not in graph.nodes) or (
"positions_targets" not in graph.nodes) or (
"positions_nan_mask" not in graph.globals):
raise ValueError("Conformer features not available to filter.")
any_nan = tf.logical_not(tf.squeeze(graph.globals["positions_nan_mask"]))
return any_nan if with_nans else tf.logical_not(any_nan)
return _filter
def _numpy_to_tensor_spec(arr: np.ndarray) -> tf.TensorSpec:
if not isinstance(arr, np.ndarray):
return tf.TensorSpec([],
dtype=tf.int32 if isinstance(arr, int) else tf.float32)
elif arr.shape:
return tf.TensorSpec((None,) + arr.shape[1:], arr.dtype)
else:
return tf.TensorSpec([], arr.dtype)
def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor:
return tf.random.categorical(tf.math.log([[1 / size] * size]), num)[0]
@jax.curry(jax.tree_map)
def _downcast_ints(x):
if x.dtype == tf.int64:
return tf.cast(x, tf.int32)
return x
def _one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor:
vocab_sizes = features.get_atom_feature_dims()
one_hots = []
for i in range(atoms.shape[1]):
one_hots.append(tf.one_hot(atoms[:, i], vocab_sizes[i], dtype=tf.float32))
return tf.concat(one_hots, axis=-1)
def _sample_one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor:
vocab_sizes = features.get_atom_feature_dims()
one_hots = []
num_atoms = tf.shape(atoms)[0]
for i in range(atoms.shape[1]):
sampled_category = _sample_uniform_categorical(num_atoms, vocab_sizes[i])
one_hots.append(
tf.one_hot(sampled_category, vocab_sizes[i], dtype=tf.float32))
return tf.concat(one_hots, axis=-1)
def _one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor:
vocab_sizes = features.get_bond_feature_dims()
one_hots = []
for i in range(bonds.shape[1]):
one_hots.append(tf.one_hot(bonds[:, i], vocab_sizes[i], dtype=tf.float32))
return tf.concat(one_hots, axis=-1)
def _sample_one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor:
vocab_sizes = features.get_bond_feature_dims()
one_hots = []
num_bonds = tf.shape(bonds)[0]
for i in range(bonds.shape[1]):
sampled_category = _sample_uniform_categorical(num_bonds, vocab_sizes[i])
one_hots.append(
tf.one_hot(sampled_category, vocab_sizes[i], dtype=tf.float32))
return tf.concat(one_hots, axis=-1)
def _maybe_one_hot_atoms_with_noise(
x,
is_training: bool,
sample_random: float,
):
"""One hot atoms with noise."""
gt_nodes = _one_hot_atoms(x.nodes)
gt_edges = _one_hot_bonds(x.edges)
if is_training:
num_nodes = tf.shape(x.nodes)[0]
sample_node_or_not = tf.random.uniform([num_nodes],
maxval=1) < sample_random
nodes = tf.where(
tf.expand_dims(sample_node_or_not, axis=-1),
_sample_one_hot_atoms(x.nodes), gt_nodes)
num_edges = tf.shape(x.edges)[0]
sample_edges_or_not = tf.random.uniform([num_edges],
maxval=1) < sample_random
edges = tf.where(
tf.expand_dims(sample_edges_or_not, axis=-1),
_sample_one_hot_bonds(x.edges), gt_edges)
else:
nodes = gt_nodes
edges = gt_edges
return x._replace(
nodes={
"atom_one_hots_targets": gt_nodes,
"atom_one_hots": nodes,
},
edges={
"bond_one_hots_targets": gt_edges,
"bond_one_hots": edges
})
def _load_smiles(
data_root: str,
split: str,
k_fold_split_id: int,
num_k_fold_splits: int,
):
"""Loads smiles trings for the input split."""
if split == "test" or k_fold_split_id is None:
indices = datasets.load_splits()[split]
elif split == "train":
indices = datasets.load_all_except_kth_fold_indices(
data_root, k_fold_split_id, num_k_fold_splits)
else:
assert split == "valid"
indices = datasets.load_kth_fold_indices(data_root, k_fold_split_id)
smiles_and_labels = datasets.load_smile_strings(with_labels=True)
smiles, labels = list(zip(*smiles_and_labels))
return indices, [smiles[i] for i in indices], [labels[i] for i in indices]
def _convert_ogb_graph_to_graphs_tuple(ogb_graph):
"""Converts an OGB Graph to a GraphsTuple."""
senders = ogb_graph["edge_index"][0]
receivers = ogb_graph["edge_index"][1]
edges = ogb_graph["edge_feat"]
nodes = ogb_graph["node_feat"]
n_node = np.array([ogb_graph["num_nodes"]])
n_edge = np.array([len(senders)])
graph = jraph.GraphsTuple(
nodes=nodes,
edges=edges,
senders=senders,
receivers=receivers,
n_node=n_node,
n_edge=n_edge,
globals=None)
return tree.map_structure(lambda x: x if x is not None else np.array(0.),
graph)
def _load_conformers(indices: List[int],
smiles: List[str],
cached_conformers_file: str):
"""Loads conformers."""
smile_to_conformer = datasets.load_cached_conformers(cached_conformers_file)
conformers = []
for graph_idx, smile in zip(indices, smiles):
del graph_idx # Unused.
if smile not in smile_to_conformer:
raise KeyError("Cache did not have conformer entry for the smile %s" %
str(smile))
conformers.append(dict(conformer=smile_to_conformer[smile]))
return conformers
def _add_conformer_features(
graph,
conformer_features,
augment_with_random_mirror_symmetry: bool,
noise_std: float,
is_training: bool,
):
"""Adds conformer features."""
if not isinstance(graph.nodes, dict):
raise ValueError("Expected a dict type for `graph.nodes`.")
# Remove mean position to center around a canonical origin.
positions = conformer_features["conformer"]
# NaN's appear in ~0.13% of training, 0.104% of validation and 0.16% of test
# nodes.
# See this colab: http://shortn/_6UcuosxY7x.
nan_mask = tf.reduce_any(tf.math.is_nan(positions))
positions = tf.where(nan_mask, tf.constant(0., positions.dtype), positions)
positions -= tf.reduce_mean(positions, axis=0, keepdims=True)
# Optionally augment with a random rotation.
if is_training:
rot_mat = conformer_utils.get_random_rotation_matrix(
augment_with_random_mirror_symmetry)
positions = conformer_utils.rotate(positions, rot_mat)
positions_targets = positions
# Optionally add noise to the positions.
if noise_std and is_training:
positions = tf.random.normal(tf.shape(positions), positions, noise_std)
return graph._replace(
nodes=dict(
positions=positions,
positions_targets=positions_targets,
**graph.nodes),
globals={
"positions_nan_mask":
tf.expand_dims(tf.logical_not(nan_mask), axis=0),
**(graph.globals if isinstance(graph.globals, dict) else {})
})
def _get_pcq_graph_generator(indices, smiles, labels, conformers):
"""Returns a generator to yield graph."""
for idx, smile, conformer_positions, label in zip(indices, smiles, conformers,
labels):
graph = utils.smiles2graph(smile)
graph = _convert_ogb_graph_to_graphs_tuple(graph)
graph = graph._replace(
globals={
"target": np.array([label], dtype=np.float32),
"graph_index": np.array([idx], dtype=np.int32),
**(graph.globals if isinstance(graph.globals, dict) else {})
})
yield graph, conformer_positions
+83
View File
@@ -0,0 +1,83 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PCQM4M-LSC datasets."""
import functools
import pickle
from typing import Dict, List, Tuple, Union
import numpy as np
from ogb import lsc
NUM_VALID_SAMPLES = 380_670
NUM_TEST_SAMPLES = 377_423
NORMALIZE_TARGET_MEAN = 5.690944545356371
NORMALIZE_TARGET_STD = 1.1561347795107815
def load_splits() -> Dict[str, List[int]]:
"""Loads dataset splits."""
dataset = _get_pcq_dataset(only_smiles=True)
return dataset.get_idx_split()
def load_kth_fold_indices(data_root: str, k_fold_split_id: int) -> List[int]:
"""Loads k-th fold indices."""
fname = f"{data_root}/k_fold_splits/{k_fold_split_id}.pkl"
return list(map(int, _load_pickle(fname)))
def load_all_except_kth_fold_indices(data_root: str, k_fold_split_id: int,
num_k_fold_splits: int) -> List[int]:
"""Loads indices except for the kth fold."""
if k_fold_split_id is None:
raise ValueError("Expected integer value for `k_fold_split_id`.")
indices = []
for index in range(num_k_fold_splits):
if index != k_fold_split_id:
indices += load_kth_fold_indices(data_root, index)
return indices
def load_smile_strings(
with_labels=False) -> List[Union[str, Tuple[str, np.ndarray]]]:
"""Loads the smile strings in the PCQ dataset."""
dataset = _get_pcq_dataset(only_smiles=True)
smiles = []
for i in range(len(dataset)):
smile, label = dataset[i]
if with_labels:
smiles.append((smile, label))
else:
smiles.append(smile)
return smiles
@functools.lru_cache()
def load_cached_conformers(cached_fname: str) -> Dict[str, np.ndarray]:
"""Returns cached dict mapping smile strings to conformer features."""
return _load_pickle(cached_fname)
@functools.lru_cache()
def _get_pcq_dataset(only_smiles: bool):
return lsc.PCQM4MDataset(only_smiles=only_smiles)
def _load_pickle(fname: str):
with open(fname, "rb") as f:
return pickle.load(f)
+99
View File
@@ -0,0 +1,99 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Download data required for training and evaluating models."""
import pathlib
from absl import app
from absl import flags
from absl import logging
from google.cloud import storage
Path = pathlib.Path
_BUCKET_NAME = 'deepmind-ogb-lsc'
_MAX_DOWNLOAD_ATTEMPTS = 5
FLAGS = flags.FLAGS
flags.DEFINE_enum('payload', None, ['data', 'models'],
'Download "data" or "models"?')
flags.DEFINE_string('task_root', None, 'Local task root directory')
# GCS_DATA_ROOT = Path('pcq/data')
# GCS_MODEL_ROOT = Path('pcq/models')
DATA_RELATIVE_PATHS = [
'raw/data.csv.gz', 'preprocessed/smile_to_conformer.pkl',
'k_fold_splits'
]
class DataCorruptionError(Exception):
pass
def _get_gcs_root():
return Path('pcq') / FLAGS.payload
def _get_gcs_bucket():
storage_client = storage.Client.create_anonymous_client()
return storage_client.bucket(_BUCKET_NAME)
def _write_blob_to_destination(blob, task_root, ignore_existing=True):
"""Write the blob."""
logging.info("Copying blob: '%s'", blob.name)
destination_path = Path(task_root) / Path(*Path(blob.name).parts[1:])
logging.info(" ... to: '%s'", str(destination_path))
if ignore_existing and destination_path.exists():
return
destination_path.parent.mkdir(parents=True, exist_ok=True)
checksum = 'crc32c'
for attempt in range(_MAX_DOWNLOAD_ATTEMPTS):
try:
blob.download_to_filename(destination_path.as_posix(), checksum=checksum)
except storage.client.resumable_media.common.DataCorruption:
pass
else:
break
else:
raise DataCorruptionError(f"Checksum ('{checksum}') for {blob.name} failed "
f'after {attempt + 1} attempts')
def main(unused_argv):
bucket = _get_gcs_bucket()
if FLAGS.payload == 'data':
relative_paths = DATA_RELATIVE_PATHS
else:
relative_paths = (None,)
for relative_path in relative_paths:
if relative_path is None:
relative_path = str(_get_gcs_root())
else:
relative_path = str(_get_gcs_root() / relative_path)
logging.info("Copying relative path: '%s'", relative_path)
blobs = bucket.list_blobs(prefix=relative_path)
for blob in blobs:
_write_blob_to_destination(blob, FLAGS.task_root)
if __name__ == '__main__':
flags.mark_flag_as_required('payload')
flags.mark_flag_as_required('task_root')
app.run(main)
+258
View File
@@ -0,0 +1,258 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to generate ensembled PCQ test predictions."""
import collections
import os
import pathlib
from typing import List, NamedTuple
from absl import app
from absl import flags
from absl import logging
import dill
import numpy as np
from ogb import lsc
# pylint: disable=g-bad-import-order
# pytype: disable=import-error
import datasets
_NUM_SEEDS = 2
_CLIP_VALUE = 20.
_NUM_KFOLD_SPLITS = 10
_SEED_START = flags.DEFINE_integer(
'seed_start', 42, 'Initial seed for the list of ensemble models.')
_CONFORMER_PATH = flags.DEFINE_string(
'conformer_path', None, 'Path to conformer predictions.', required=True)
_NON_CONFORMER_PATH = flags.DEFINE_string(
'non_conformer_path',
None,
'Path to non-conformer predictions.',
required=True)
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Output path.')
_SPLIT = flags.DEFINE_enum('split', 'test', ['test', 'valid'],
'Split: valid or test.')
class _Predictions(NamedTuple):
predictions: np.ndarray
indices: np.ndarray
def _load_dill(fname) -> bytes:
with open(fname, 'rb') as f:
return dill.load(f)
def _sort_by_indices(predictions: _Predictions) -> _Predictions:
order = np.argsort(predictions.indices)
return _Predictions(
predictions=predictions.predictions[order],
indices=predictions.indices[order])
def load_predictions(path: str, split: str) -> _Predictions:
"""Load written prediction file."""
if len(os.listdir(path)) != 1:
raise ValueError('Prediction directory must have exactly '
'one prediction sub-directory: %s' % path)
prediction_subdir = os.listdir(path)[0]
return _Predictions(*_load_dill(f'{path}/{prediction_subdir}/{split}.dill'))
def mean_mae_distance(x, y):
return np.abs(x - y).mean()
def _load_valid_labels() -> np.ndarray:
labels = [label for _, label in datasets.load_smile_strings(with_labels=True)]
return np.array([labels[i] for i in datasets.load_splits()['valid']])
def evaluate_valid_predictions(ensembled_predictions: _Predictions):
"""Evaluates the predictions on the validation set."""
ensembled_predictions = _sort_by_indices(ensembled_predictions)
evaluator = lsc.PCQM4MEvaluator()
results = evaluator.eval(
dict(
y_pred=ensembled_predictions.predictions,
y_true=_load_valid_labels()))
logging.info('MAE on validation dataset: %f', results['mae'])
def clip_predictions(predictions: _Predictions) -> _Predictions:
return predictions._replace(
predictions=np.clip(predictions.predictions, 0., _CLIP_VALUE))
def _generate_test_prediction_file(test_predictions: np.ndarray,
output_path: pathlib.Path) -> pathlib.Path:
"""Generates the final file for submission."""
# Check that predictions are not nuts.
assert test_predictions.dtype in [np.float64, np.float32]
assert not np.any(np.isnan(test_predictions))
assert np.all(np.isfinite(test_predictions))
assert test_predictions.min() >= 0.
assert test_predictions.max() <= 40.
# Too risky to overwrite.
if output_path.exists():
raise ValueError(f'{output_path} already exists')
# Write to a local directory, and copy to final path (possibly cns).
# It is not possible to write directlt on CNS.
evaluator = lsc.PCQM4MEvaluator()
evaluator.save_test_submission(dict(y_pred=test_predictions), output_path)
return output_path
def merge_complementary_results(split: str, results_a: _Predictions,
results_b: _Predictions) -> _Predictions:
"""Merges two prediction results with no overlap."""
indices_a = set(results_a.indices)
indices_b = set(results_b.indices)
assert not indices_a.intersection(indices_b)
if split == 'test':
merged_indices = list(sorted(indices_a | indices_b))
expected_indices = datasets.load_splits()[split]
assert np.all(expected_indices == merged_indices)
predictions = np.concatenate([results_a.predictions, results_b.predictions])
indices = np.concatenate([results_a.indices, results_b.indices])
predictions = _sort_by_indices(
_Predictions(indices=indices, predictions=predictions))
return predictions
def ensemble_valid_predictions(
predictions_list: List[_Predictions]) -> _Predictions:
"""Ensembles a list of predictions."""
index_to_predictions = collections.defaultdict(list)
for predictions in predictions_list:
for idx, pred in zip(predictions.indices, predictions.predictions):
index_to_predictions[idx].append(pred)
for idx, ensemble_list in index_to_predictions.items():
if len(ensemble_list) != _NUM_SEEDS:
raise RuntimeError(
'Graph index in the validation set received wrong number of '
'predictions to ensemble.')
index_to_predictions = {
k: np.median(pred_list, axis=0)
for k, pred_list in index_to_predictions.items()
}
return _sort_by_indices(
_Predictions(
indices=np.array(list(index_to_predictions.keys())),
predictions=np.array(list(index_to_predictions.values()))))
def ensemble_test_predictions(
predictions_list: List[_Predictions]) -> _Predictions:
"""Ensembles a list of predictions."""
predictions = np.median([pred.predictions for pred in predictions_list],
axis=0)
common_indices = predictions_list[0].indices
for preds in predictions_list[1:]:
assert np.all(preds.indices == common_indices)
return _Predictions(predictions=predictions, indices=common_indices)
def create_submission_from_predictions(
output_path: pathlib.Path, test_predictions: _Predictions) -> pathlib.Path:
"""Creates a submission for predictions on a path."""
assert _SPLIT.value == 'test'
output_path = _generate_test_prediction_file(
test_predictions.predictions,
output_path=output_path / 'submission_files')
return output_path / 'y_pred_pcqm4m.npz'
def merge_predictions(split: str) -> List[_Predictions]:
"""Generates features merged from conformer and non-conformer predictions."""
merged_predictions: List[_Predictions] = []
seed = _SEED_START.value
# Load conformer and non-conformer predictions.
for unused_seed_group in (0, 1):
for k in range(_NUM_KFOLD_SPLITS):
conformer_predictions: _Predictions = load_predictions(
f'{_CONFORMER_PATH.value}/k{k}_seed{seed}', split)
non_conformer_predictions: _Predictions = load_predictions(
f'{_NON_CONFORMER_PATH.value}/k{k}_seed{seed}', split)
merged_predictions.append(
merge_complementary_results(_SPLIT.value, conformer_predictions,
non_conformer_predictions))
seed += 1
return merged_predictions
def main(_):
split: str = _SPLIT.value
# Merge conformer and non-conformer predictions.
merged_predictions = merge_predictions(split)
# Clip before ensembling.
clipped_predictions = list(map(clip_predictions, merged_predictions))
# Ensemble predictions.
if split == 'valid':
ensembled_predictions = ensemble_valid_predictions(clipped_predictions)
else:
assert split == 'test'
ensembled_predictions = ensemble_test_predictions(clipped_predictions)
# Clip after ensembling.
ensembled_predictions = clip_predictions(ensembled_predictions)
ensembled_predictions_path = pathlib.Path(_OUTPUT_PATH.value)
ensembled_predictions_path.mkdir(parents=True, exist_ok=True)
with open(ensembled_predictions_path / f'{split}_predictions.dill',
'wb') as f:
dill.dump(ensembled_predictions, f)
if split == 'valid':
evaluate_valid_predictions(ensembled_predictions)
else:
assert split == 'test'
output_path = create_submission_from_predictions(ensembled_predictions_path,
ensembled_predictions)
logging.info('Submission files written to %s', output_path)
if __name__ == '__main__':
app.run(main)
+449
View File
@@ -0,0 +1,449 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PCQM4M-LSC Jaxline experiment."""
import datetime
import functools
import os
import signal
import threading
from typing import Iterable, Mapping, NamedTuple, Tuple
from absl import app
from absl import flags
from absl import logging
import chex
import dill
import haiku as hk
import jax
from jax.config import config as jax_config
import jax.numpy as jnp
from jaxline import experiment
from jaxline import platform
from jaxline import utils
import jraph
import numpy as np
import optax
import tensorflow as tf
import tree
# pylint: disable=g-bad-import-order
import dataset_utils
import datasets
import model
FLAGS = flags.FLAGS
def _get_step_date_label(global_step: int):
# Date removing microseconds.
date_str = datetime.datetime.now().isoformat().split('.')[0]
return f'step_{global_step}_{date_str}'
class _Predictions(NamedTuple):
predictions: np.ndarray
indices: np.ndarray
def tf1_ema(ema_value, current_value, decay, step):
"""Implements EMA with TF1-style decay warmup."""
decay = jnp.minimum(decay, (1.0 + step) / (10.0 + step))
return ema_value * decay + current_value * (1 - decay)
def _sort_predictions_by_indices(predictions: _Predictions):
sorted_order = np.argsort(predictions.indices)
return _Predictions(
predictions=predictions.predictions[sorted_order],
indices=predictions.indices[sorted_order])
class Experiment(experiment.AbstractExperiment):
"""OGB Graph Property Prediction GraphNet experiment."""
CHECKPOINT_ATTRS = {
'_params': 'params',
'_opt_state': 'opt_state',
'_network_state': 'network_state',
'_ema_network_state': 'ema_network_state',
'_ema_params': 'ema_params',
}
def __init__(self, mode, init_rng, config):
"""Initializes experiment."""
super(Experiment, self).__init__(mode=mode, init_rng=init_rng)
if mode not in ('train', 'eval', 'train_eval_multithreaded'):
raise ValueError(f'Invalid mode {mode}.')
# Do not use accelerators in data pipeline.
tf.config.experimental.set_visible_devices([], device_type='gpu')
tf.config.experimental.set_visible_devices([], device_type='tpu')
self.mode = mode
self.init_rng = init_rng
self.config = config
self.loss = None
self.forward = None
# Needed for checkpoint restore.
self._params = None
self._network_state = None
self._opt_state = None
self._ema_network_state = None
self._ema_params = None
# _ _
# | |_ _ __ __ _(_)_ __
# | __| "__/ _` | | "_ \
# | |_| | | (_| | | | | |
# \__|_| \__,_|_|_| |_|
#
def step(self, global_step: jnp.ndarray, rng: jnp.ndarray, **unused_args):
"""See Jaxline base class."""
if self.loss is None:
self._train_init()
graph = next(self._train_input)
out = self.update_parameters(
self._params,
self._ema_params,
self._network_state,
self._ema_network_state,
self._opt_state,
global_step,
rng,
graph._asdict())
(self._params, self._ema_params, self._network_state,
self._ema_network_state, self._opt_state, scalars) = out
return utils.get_first(scalars)
def _construct_loss_config(self):
loss_config = getattr(model, self.config.model.loss_config_name)
if self.config.model.loss_config_name == 'RegressionLossConfig':
return loss_config(
mean=datasets.NORMALIZE_TARGET_MEAN,
std=datasets.NORMALIZE_TARGET_STD,
kwargs=self.config.model.loss_kwargs)
else:
raise ValueError('Unknown Loss Config')
def _train_init(self):
self.loss = hk.transform_with_state(self._loss)
self._train_input = utils.py_prefetch(
lambda: self._build_numpy_dataset_iterator('train', is_training=True))
init_stacked_graphs = next(self._train_input)
init_key = utils.bcast_local_devices(self.init_rng)
p_init = jax.pmap(self.loss.init)
self._params, self._network_state = p_init(init_key,
**init_stacked_graphs._asdict())
# Learning rate scheduling.
lr_schedule = optax.warmup_cosine_decay_schedule(
**self.config.optimizer.lr_schedule)
self.optimizer = getattr(optax, self.config.optimizer.name)(
learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs)
self._opt_state = jax.pmap(self.optimizer.init)(self._params)
self.update_parameters = jax.pmap(self._update_parameters, axis_name='i')
if self.config.ema:
self._ema_params = self._params
self._ema_network_state = self._network_state
def _loss(
self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree:
graph = jraph.GraphsTuple(**graph)
model_instance = model.GraphPropertyEncodeProcessDecode(
loss_config=self._construct_loss_config(), **self.config.model)
loss, scalars = model_instance.get_loss(graph)
return loss, scalars
def _maybe_save_predictions(
self,
predictions: jnp.ndarray,
split: str,
global_step: jnp.ndarray,
):
if not self.config.predictions_dir:
return
output_dir = os.path.join(self.config.predictions_dir,
_get_step_date_label(global_step))
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, split + '.dill')
with open(output_path, 'wb') as f:
dill.dump(predictions, f)
logging.info('Saved %s predictions at: %s', split, output_path)
def _build_numpy_dataset_iterator(self, split: str, is_training: bool):
dynamic_batch_size_config = (
self.config.training.dynamic_batch_size
if is_training else self.config.evaluation.dynamic_batch_size)
return dataset_utils.build_dataset_iterator(
split=split,
dynamic_batch_size_config=dynamic_batch_size_config,
sample_random=self.config.sample_random,
debug=self.config.debug,
is_training=is_training,
**self.config.dataset_config)
def _update_parameters(
self,
params: hk.Params,
ema_params: hk.Params,
network_state: hk.State,
ema_network_state: hk.State,
opt_state: optax.OptState,
global_step: jnp.ndarray,
rng: jnp.ndarray,
graph: jraph.GraphsTuple,
) -> Tuple[hk.Params, hk.Params, hk.State, hk.State, optax.OptState,
chex.ArrayTree]:
"""Updates parameters."""
def get_loss(*x, **graph):
(loss, scalars), network_state = self.loss.apply(*x, **graph)
return loss, (scalars, network_state)
grad_loss_fn = jax.grad(get_loss, has_aux=True)
scaled_grads, (scalars, network_state) = grad_loss_fn(
params, network_state, rng, **graph)
grads = jax.lax.psum(scaled_grads, axis_name='i')
updates, opt_state = self.optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
if ema_params is not None:
ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step)
ema_params = jax.tree_multimap(ema, ema_params, params)
ema_network_state = jax.tree_multimap(ema, ema_network_state,
network_state)
return params, ema_params, network_state, ema_network_state, opt_state, scalars
# _
# _____ ____ _| |
# / _ \ \ / / _` | |
# | __/\ V / (_| | |
# \___| \_/ \__,_|_|
#
def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray,
**unused_kwargs) -> chex.ArrayTree:
"""See Jaxline base class."""
if self.forward is None:
self._eval_init()
if self.config.ema:
params = utils.get_first(self._ema_params)
state = utils.get_first(self._ema_network_state)
else:
params = utils.get_first(self._params)
state = utils.get_first(self._network_state)
rng = utils.get_first(rng)
split = self.config.evaluation.split
predictions, scalars = self._get_predictions(
params, state, rng,
utils.py_prefetch(
functools.partial(
self._build_numpy_dataset_iterator, split, is_training=False)))
self._maybe_save_predictions(predictions, split, global_step[0])
return scalars
def _sum_regression_scalars(self, preds: jnp.ndarray,
graph: jraph.GraphsTuple) -> chex.ArrayTree:
"""Creates unnormalised values for accumulation."""
targets = graph.globals['target']
graph_mask = jraph.get_graph_padding_mask(graph)
# Sum for accumulation, normalise later since there are a
# variable number of graphs per batch.
mae = model.sum_with_mask(jnp.abs(targets - preds), graph_mask)
mse = model.sum_with_mask((targets - preds)**2, graph_mask)
count = jnp.sum(graph_mask)
return {'values': {'mae': mae.item(), 'mse': mse.item()},
'counts': {'mae': count.item(), 'mse': count.item()}}
def _get_prediction(
self,
params: hk.Params,
state: hk.State,
rng: jnp.ndarray,
graph: jraph.GraphsTuple,
) -> np.ndarray:
"""Returns predictions for all the graphs in the dataset split."""
model_out, _ = self.eval_apply(params, state, rng, **graph._asdict())
prediction = np.squeeze(model_out['globals'], axis=1)
return prediction
def _get_predictions(
self,
params: hk.Params,
state: hk.State,
rng: jnp.ndarray,
graph_iterator: Iterable[jraph.GraphsTuple],
) -> Tuple[_Predictions, chex.ArrayTree]:
all_scalars = []
predictions = []
graph_indices = []
for i, graph in enumerate(graph_iterator):
prediction = self._get_prediction(params, state, rng, graph)
if 'target' in graph.globals and not jnp.isnan(
graph.globals['target']).any():
scalars = self._sum_regression_scalars(prediction, graph)
all_scalars.append(scalars)
num_padding_graphs = jraph.get_number_of_padding_with_graphs_graphs(graph)
num_valid_graphs = len(graph.n_node) - num_padding_graphs
depadded_prediction = prediction[:num_valid_graphs]
predictions.append(depadded_prediction)
graph_indices.append(graph.globals['graph_index'][:num_valid_graphs])
if i % 1000 == 0:
logging.info('Generated predictions for %d batches so far', i + 1)
predictions = _sort_predictions_by_indices(
_Predictions(
predictions=np.concatenate(predictions),
indices=np.concatenate(graph_indices)))
if all_scalars:
sum_all_args = lambda *l: sum(l)
# Sum over graphs in the dataset.
accum_scalars = tree.map_structure(sum_all_args, *all_scalars)
scalars = tree.map_structure(lambda x, y: x / y, accum_scalars['values'],
accum_scalars['counts'])
else:
scalars = {}
return predictions, scalars
def _eval_init(self):
self.forward = hk.transform_with_state(self._forward)
self.eval_apply = jax.jit(self.forward.apply)
def _forward(self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree:
graph = jraph.GraphsTuple(**graph)
model_instance = model.GraphPropertyEncodeProcessDecode(
loss_config=self._construct_loss_config(), **self.config.model)
return model_instance(graph)
def _restore_state_to_in_memory_checkpointer(restore_path):
"""Initializes experiment state from a checkpoint."""
# Load pretrained experiment state.
python_state_path = os.path.join(restore_path, 'checkpoint.dill')
with open(python_state_path, 'rb') as f:
pretrained_state = dill.load(f)
logging.info('Restored checkpoint from %s', python_state_path)
# Assign state to a dummy experiment instance for the in-memory checkpointer,
# broadcasting to devices.
dummy_experiment = Experiment(
mode='train', init_rng=0, config=FLAGS.config.experiment_kwargs.config)
for attribute, key in Experiment.CHECKPOINT_ATTRS.items():
setattr(dummy_experiment, attribute,
utils.bcast_local_devices(pretrained_state[key]))
jaxline_state = dict(
global_step=pretrained_state['global_step'],
experiment_module=dummy_experiment)
snapshot = utils.SnapshotNT(0, jaxline_state)
# Finally, seed the jaxline `utils.InMemoryCheckpointer` global dict.
utils.GLOBAL_CHECKPOINT_DICT['latest'] = utils.CheckpointNT(
threading.local(), [snapshot])
def _save_state_from_in_memory_checkpointer(
save_path, experiment_class: experiment.AbstractExperiment):
"""Saves experiment state to a checkpoint."""
logging.info('Saving model.')
for checkpoint_name, checkpoint in utils.GLOBAL_CHECKPOINT_DICT.items():
if not checkpoint.history:
logging.info('Nothing to save in "%s"', checkpoint_name)
continue
pickle_nest = checkpoint.history[-1].pickle_nest
global_step = pickle_nest['global_step']
state_dict = {'global_step': global_step}
for attribute, key in experiment_class.CHECKPOINT_ATTRS.items():
state_dict[key] = utils.get_first(
getattr(pickle_nest['experiment_module'], attribute))
save_dir = os.path.join(
save_path, checkpoint_name, _get_step_date_label(global_step))
python_state_path = os.path.join(save_dir, 'checkpoint.dill')
os.makedirs(save_dir, exist_ok=True)
with open(python_state_path, 'wb') as f:
dill.dump(state_dict, f)
logging.info(
'Saved "%s" checkpoint to %s', checkpoint_name, python_state_path)
def _setup_signals(save_model_fn):
"""Sets up a signal for model saving."""
# Save a model on Ctrl+C.
def sigint_handler(unused_sig, unused_frame):
# Ideally, rather than saving immediately, we would then "wait" for a good
# time to save. In practice this reads from an in-memory checkpoint that
# only saves every 30 seconds or so, so chances of race conditions are very
# small.
save_model_fn()
logging.info(r'Use `Ctrl+\` to save and exit.')
# Exit on `Ctrl+\`, saving a model.
prev_sigquit_handler = signal.getsignal(signal.SIGQUIT)
def sigquit_handler(unused_sig, unused_frame):
# Restore previous handler early, just in case something goes wrong in the
# next lines, so it is possible to press again and exit.
signal.signal(signal.SIGQUIT, prev_sigquit_handler)
save_model_fn()
logging.info(r'Exiting on `Ctrl+\`')
# Re-raise for clean exit.
os.kill(os.getpid(), signal.SIGQUIT)
signal.signal(signal.SIGINT, sigint_handler)
signal.signal(signal.SIGQUIT, sigquit_handler)
def main(argv, experiment_class: experiment.AbstractExperiment):
# Maybe restore a model.
restore_path = FLAGS.config.restore_path
if restore_path:
_restore_state_to_in_memory_checkpointer(restore_path)
# Maybe save a model.
save_dir = os.path.join(FLAGS.config.checkpoint_dir, 'models')
if FLAGS.config.one_off_evaluate:
save_model_fn = lambda: None # No need to save checkpoint in this case.
else:
save_model_fn = functools.partial(
_save_state_from_in_memory_checkpointer, save_dir, experiment_class)
_setup_signals(save_model_fn) # Save on Ctrl+C (continue) or Ctrl+\ (exit).
try:
platform.main(experiment_class, argv)
finally:
save_model_fn() # Save at the end of training or in case of exception.
if __name__ == '__main__':
jax_config.update('jax_debug_nans', False)
flags.mark_flag_as_required('config')
app.run(lambda argv: main(argv, Experiment))
@@ -0,0 +1,67 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate conformer features to be used for training/predictions."""
import multiprocessing as mp
import pickle
from typing import List
from absl import app
from absl import flags
import numpy as np
# pylint: disable=g-bad-import-order
import conformer_utils
import datasets
_SPLITS = flags.DEFINE_spaceseplist(
'splits', ['test'], 'Splits to compute conformer features for.')
_OUTPUT_FILE = flags.DEFINE_string(
'output_file',
None,
required=True,
help='Output file name to write the generated conformer features to.')
_NUM_PROCS = flags.DEFINE_integer(
'num_parallel_procs', 64,
'Number of parallel processes to use for conformer generation.')
def generate_conformer_features(smiles: List[str]) -> List[np.ndarray]:
# Conformer generation is a CPU-bound task and hence can get a boost from
# parallel processing.
# To avoid GIL, we choose multiprocessing instead of the
# simpler multi-threading option here for parallel computing.
with mp.Pool(_NUM_PROCS.value) as pool:
return list(pool.map(conformer_utils.compute_conformer, smiles))
def main(_):
smiles = datasets.load_smile_strings(with_labels=False)
indices = set()
for split in _SPLITS.value:
indices.update(datasets.load_splits()[split])
smiles = [smiles[i] for i in sorted(indices)]
conformers = generate_conformer_features(smiles)
smiles_to_conformers = dict(zip(smiles, conformers))
with open(_OUTPUT_FILE.value, 'wb') as f:
pickle.dump(smiles_to_conformers, f)
if __name__ == '__main__':
app.run(main)
+50
View File
@@ -0,0 +1,50 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates the k-fold validation splits."""
import os
import pickle
from absl import app
from absl import flags
from absl import logging
import numpy as np
# pylint: disable=g-bad-import-order
import datasets
_OUTPUT_DIR = flags.DEFINE_string(
'output_dir', None, required=True,
help='Output directory to write the splits to')
K = 10
def main(argv):
del argv
valid_indices = datasets.load_splits()['valid']
k_splits = np.split(valid_indices, K)
os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
for k_i, split in enumerate(k_splits):
fname = os.path.join(_OUTPUT_DIR.value, f'{k_i}.pkl')
with open(fname, 'wb') as f:
pickle.dump(split, f)
logging.info('Saved: %s', fname)
if __name__ == '__main__':
app.run(main)
File diff suppressed because it is too large Load Diff
+21
View File
@@ -0,0 +1,21 @@
wheel
absl-py>=0.12.0
chex>=0.0.7
dill>=0.3.3
dm-haiku>=0.0.4
# jaxline
git+https://github.com/deepmind/jaxline@927695a0b88e480d085590736e2daa36c2f2c68b
optax>=0.0.8
ml_collections
numpy>=1.16.4
tensorflow>=2.5.0
tensorflow-datasets>=4.3.0
dm-tree>=0.1.6
ogb>=1.3.1
# graph_nets
git+git://github.com/deepmind/graph_nets.git@64771dff0d74ca8e77b1f1dcd5a7d26634356d61
rdkit-pypi>=2021.3.2.3
jaxlib>=0.1.57+cuda110
# jraph
git+git://github.com/deepmind/jraph.git@80d2f9e4d82f841a71d56ba44fa9e8781e93ae1f
google-cloud-storage
+48
View File
@@ -0,0 +1,48 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x
while getopts ":r:" opt; do
case ${opt} in
r )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: run_training.sh -r <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
DATA_ROOT=${TASK_ROOT}/data/
python "${SCRIPT_DIR}"/generate_validation_splits.py \
--output_dir="${DATA_ROOT}/k_fold_splits"
mkdir -p ${DATA_ROOT}/preprocessed/
python "${SCRIPT_DIR}"/generate_conformer_features.py \
--splits="test valid train" \
--num_parallel_procs=32 \
--output_file="${DATA_ROOT}/preprocessed/smile_to_conformer.pkl"
+107
View File
@@ -0,0 +1,107 @@
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Generates test predictions from trained model checkpoints.
set -e
set -x
while getopts ":r:" opt; do
case ${opt} in
r )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: run_pretrained_eval.sh -r <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
SPLIT="test"
DATA_ROOT=${TASK_ROOT}/data/ # For valid k-fold splits.
MODELS_ROOT=${TASK_ROOT}/models/
CACHED_CONFORMERS_DIR=${TASK_ROOT}/online_conformers/$SPLIT
CACHED_CONFORMERS_FILE=${CACHED_CONFORMERS_DIR}/smiles_to_conformers.pkl
OUTPUT_DIR=${TASK_ROOT}/predictions/
# First generate and cache the conformer feature data for the test split.
mkdir -p ${CACHED_CONFORMERS_DIR}
time python "${SCRIPT_DIR}"/generate_conformer_features.py \
--splits=$SPLIT \
--num_parallel_procs=32 \
--output_file=${CACHED_CONFORMERS_FILE}
seed=42
# Share GPU between two runs by disabling XLA memory pre-allocation.
export XLA_PYTHON_CLIENT_PREALLOCATE=false
for seed_group in 0 1; do
for k in `seq 0 9`; do
# Conformer
predictions_dir="${OUTPUT_DIR}/predictions/${SPLIT}/conformer/k${k}_seed${seed}"
time python "${SCRIPT_DIR}"/experiment.py \
--jaxline_mode="eval" \
--config="${SCRIPT_DIR}"/config.py \
--config.experiment_kwargs.config.evaluation.split=$SPLIT \
--config.restore_path=${MODELS_ROOT}/conformer/k${k}_seed${seed} \
--config.experiment_kwargs.config.predictions_dir=${predictions_dir} \
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=$k \
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
--config.experiment_kwargs.config.dataset_config.filter_in_or_out_samples_with_nans_in_conformers="out" \
--config.experiment_kwargs.config.model.latent_size=256 \
--config.experiment_kwargs.config.model.mlp_hidden_size=1024 \
--config.experiment_kwargs.config.model.num_message_passing_steps=32 \
--config.one_off_evaluate &
# Non-Conformer
predictions_dir="${OUTPUT_DIR}/predictions/${SPLIT}/non_conformer/k${k}_seed${seed}"
time python "${SCRIPT_DIR}"/experiment.py \
--jaxline_mode="eval" \
--config="${SCRIPT_DIR}"/config.py \
--config.experiment_kwargs.config.evaluation.split=$SPLIT \
--config.restore_path=${MODELS_ROOT}/non_conformer/k${k}_seed${seed} \
--config.experiment_kwargs.config.predictions_dir=${predictions_dir} \
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=$k \
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
--config.experiment_kwargs.config.dataset_config.filter_in_or_out_samples_with_nans_in_conformers="in" \
--config.experiment_kwargs.config.model.num_message_passing_steps=50 \
--config.experiment_kwargs.config.model.add_relative_distance=false \
--config.experiment_kwargs.config.model.add_relative_displacement=false \
--config.one_off_evaluate &
wait
((seed=seed+1))
done
done
# Ensemble predictions.
time python "${SCRIPT_DIR}"/ensemble_predictions.py \
--seed_start=42 \
--split=$SPLIT \
--conformer_path="${OUTPUT_DIR}/predictions/${SPLIT}/conformer" \
--non_conformer_path="${OUTPUT_DIR}/predictions/${SPLIT}/non_conformer" \
--output_path="${OUTPUT_DIR}/ensembled_predictions/"
+134
View File
@@ -0,0 +1,134 @@
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
set -x
while getopts ":r:" opt; do
case ${opt} in
r )
TASK_ROOT=$OPTARG
;;
\? )
echo "Usage: run_training.sh -r <Task root directory>"
;;
: )
echo "Invalid option: $OPTARG requires an argument" 1>&2
;;
esac
done
shift $((OPTIND -1))
# Get this script's directory.
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
echo "
These scripts are provided for illustrative purposes. It is not practical for
actual training since it only uses a single machine, and likely requires
reducing the batch size and/or model size to fit on a single GPU.
For the actual submission we used training distributed in different ways:
* We used 4x Cloud TPU v4s to implement batch parallelism, so batch
size is effectively 8 times larger than the values in the config.
* We ran the online early-stopping evaluator on a separate machine with an
NVIDIA V100 GPU.
* We run identical replicas of all of the above for each of the 40 models
trained.
Using those mechanisms, training runs at ~4 and 6 steps per second, reaching
500k steps in about 1.5 days.
"
echo "
Pre-requisites (See README):
* cd into directory containing the 'ogb_lsc' folder.
* Python dependencies have been installed.
* pre-processed data is available at the hardcoded paths
"
read -p "Press enter to continue"
DATA_ROOT=${TASK_ROOT}/data/
CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/
CACHED_CONFORMERS_FILE="${DATA_ROOT}/preprocessed/smile_to_conformer.pkl"
# We run two seeds for each model of the k=10 k-fold.
BASE_SEED=42
for SEED_OFFSET in 0 10; do
for K_FOLD_INDEX in {0..9}; do
MODEL_SEED=`expr ${BASE_SEED} + ${SEED_OFFSET} + ${K_FOLD_INDEX}`
echo "Running k=${K_FOLD_INDEX} with init seed ${MODEL_SEED}"
# This runs training (each model is trained on train split + 90% of the
# validation split) with early stopping, storing both "latest" model and
# "best" early-stopped model at `--config.checkpoint_dir`.
# Models are early stopped based on accuracy of on 10% of the validation data
# (each K_FOLD_INDEX leaves a different 10% of data out from training) left
# out from training.
# Models are stored at the end of training. Intermediate models can also be
# stored while training by sending a SIGINT signal (Ctrl+C) which will not
# interrupt the training.
# It is possible to interrupt training using (Ctrl+\) and then continue
# passing the corresponding `--config.restore_path=${RESTORE_PATH}` which is
# stored when interrupting.
# Conformer.
SUFFIX="conformer/k${K_FOLD_INDEX}_seed${MODEL_SEED}"
echo "Running ${SUFFIX}"
python "${SCRIPT_DIR}"/experiment.py \
--jaxline_mode="train_eval_multithreaded" \
--config="${SCRIPT_DIR}"/config.py \
--config.random_seed=${MODEL_SEED} \
--config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=${K_FOLD_INDEX} \
--config.experiment_kwargs.config.dataset_config.num_k_fold_splits=10 \
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
--config.experiment_kwargs.config.model.latent_size=256 \
--config.experiment_kwargs.config.model.mlp_hidden_size=1024 \
--config.experiment_kwargs.config.model.num_message_passing_steps=32 \
--config.experiment_kwargs.config.evaluation.split="valid"
# Non-Conformer.
SUFFIX="non_conformer/k${K_FOLD_INDEX}_seed${MODEL_SEED}"
echo "Running ${SUFFIX}"
python "${SCRIPT_DIR}"/experiment.py \
--jaxline_mode="train_eval_multithreaded" \
--config="${SCRIPT_DIR}"/config.py \
--config.random_seed=${MODEL_SEED} \
--config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=${K_FOLD_INDEX} \
--config.experiment_kwargs.config.dataset_config.num_k_fold_splits=10 \
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
--config.experiment_kwargs.config.model.num_message_passing_steps=50 \
--config.experiment_kwargs.config.model.add_relative_distance=false \
--config.experiment_kwargs.config.model.add_relative_displacement=false \
--config.experiment_kwargs.config.evaluation.split="valid"
done
done
# Each of the 40 (two for each value of k, and both conformer and non conformer)
# jobs, it will generate paths of the form with the early-stopped models:
# RESTORE_PATH=${--config.checkpoint_dir}/models/best/step_${STEP}_${TIMESTAMP}
# These can then be used as the "RESTORE_PATHS" for `./run_pretrained_eval.sh`.
echo "Done"