mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-31 21:15:21 +08:00
Add OGB-LSC code.
PiperOrigin-RevId: 379591158
committed by
Saran Tunyasuvunakool
parent
88f674ad87
commit
a3c133404e
@@ -0,0 +1,29 @@
|
||||
# DeepMind entry for PCQM4M-LSC
|
||||
|
||||
This repository contains DeepMind's entry to the [PCQM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry) and
|
||||
[MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph)
|
||||
tracks of the [OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/)
|
||||
(OGB-LSC).
|
||||
|
||||
For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
|
||||
|
||||
## Code structure
|
||||
|
||||
* `pcq/`: Scripts for training and evaluating on the [PCQ dataset](https://ogb.stanford.edu/docs/graphprop/).
|
||||
* `mag/`: Scripts for training and evaluating on the [MAG dataset](https://ogb.stanford.edu/docs/nodeprop/).
|
||||
|
||||
## Citation
|
||||
|
||||
To cite this work:
|
||||
|
||||
```latex
|
||||
@article{deepmind2021ogb,
|
||||
author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea
|
||||
Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro
|
||||
Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar
|
||||
Veli\v{c}kovi\'{c}},
|
||||
title = {Large-scale graph representation learning with very deep GNNs and
|
||||
self-supervision},
|
||||
year = {2021},
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,128 @@
|
||||
# DeepMind entry for MAG240M-LSC
|
||||
|
||||
This repository contains DeepMind's entry to the [MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph) track of the
|
||||
[OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/) (OGB-LSC).
|
||||
|
||||
For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
|
||||
|
||||
## DeepMind MAG Team ("Academic")
|
||||
|
||||
(in alphabetical order)
|
||||
|
||||
- Ravichandra Addanki
|
||||
- Peter Battaglia
|
||||
- David Budden
|
||||
- Andreea Deac
|
||||
- Jonathan Godwin
|
||||
- Thomas Keck
|
||||
- Alvaro Sanchez-Gonzalez
|
||||
- Jacklynn Stott
|
||||
- Shantanu Thakoor
|
||||
- Petar Veličković
|
||||
|
||||
## Performance
|
||||
|
||||
Our final test set performance was achieved by pooling an ensemble of 10 folds.
|
||||
See [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf) for details.
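
As a rough illustration of what pooling an ensemble of fold models can look like, the sketch below averages per-class probabilities over folds and then takes the arg-max. This is an assumption made for illustration only: the exact pooling used for the submission is described in the technical report, and `average_fold_probabilities` and `predict` are hypothetical helpers that are not part of this repository.

```python
from typing import List, Sequence


def average_fold_probabilities(
    fold_probs: Sequence[Sequence[Sequence[float]]]) -> List[List[float]]:
  """Averages class probabilities over folds.

  `fold_probs[k][i][c]` is the probability that the model trained on fold `k`
  assigns to class `c` for example `i`.
  """
  num_folds = len(fold_probs)
  num_examples = len(fold_probs[0])
  num_classes = len(fold_probs[0][0])
  mean = [[0.0] * num_classes for _ in range(num_examples)]
  for probs in fold_probs:
    for i, row in enumerate(probs):
      for c, p in enumerate(row):
        mean[i][c] += p / num_folds
  return mean


def predict(mean_probs: Sequence[Sequence[float]]) -> List[int]:
  """Returns the index of the most probable class for each example."""
  return [max(range(len(row)), key=row.__getitem__) for row in mean_probs]


# Three toy "folds", two examples, two classes.
toy_fold_probs = [
    [[0.60, 0.40], [0.10, 0.90]],
    [[0.30, 0.70], [0.60, 0.40]],
    [[0.80, 0.20], [0.55, 0.45]],
]
toy_mean = average_fold_probabilities(toy_fold_probs)
toy_predictions = predict(toy_mean)
```

Averaging probabilities lets one confident fold outweigh two marginal ones, as in the second toy example, which a majority vote would label differently.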
|
||||
|
||||
Each model was trained for < 72 hours using 4x Google Cloud TPUv4 and 1x AMD
|
||||
EPYC 7B12 64-core CPU @2.25GHz.
|
||||
|
||||
Inference takes < 12 hours on 4x NVIDIA V100 16GB GPU and 1x Intel Xeon Gold
|
||||
6148 20-core CPU @2.40GHz.
|
||||
|
||||
# Running our model
|
||||
|
||||
## Setup
|
||||
|
||||
You can set up a Python virtual environment (you might need to install the
|
||||
`python3-venv` package first) with all needed dependencies inside the forked
|
||||
`deepmind_research` repository using:
|
||||
|
||||
```bash
|
||||
python3 -m venv /tmp/mag_venv
|
||||
source /tmp/mag_venv/bin/activate
|
||||
pip3 install --upgrade pip setuptools wheel
|
||||
pip3 install -r ogb_lsc/mag/requirements.txt
|
||||
```
|
||||
|
||||
## Download and pre-process data
|
||||
|
||||
|
||||
**1. Download the dataset using the contest toolkit ([here](https://ogb.stanford.edu/kddcup2021/mag240m/#dataset)) to a local directory
|
||||
`ROOT`.**
|
||||
|
||||
**2. Run this script to reorganize the data into a flat directory structure with
|
||||
transparent names**
|
||||
|
||||
```bash
|
||||
/bin/bash organize_data.sh -r ROOT
|
||||
```
|
||||
|
||||
Once this completes, a new directory `ROOT/mag240m_kddcup2021/raw` will be
|
||||
created, with contents:
|
||||
|
||||
- `node_feat.npy`
|
||||
- `node_label.npy`
|
||||
- `node_year.npy`
|
||||
- `author_affiliated_with_institution_edges.npy`
|
||||
- `author_writes_paper_edges.npy`
|
||||
- `paper_cites_paper_edges.npy`
|
||||
- `train_idx.npy`
|
||||
- `valid_idx.npy`
|
||||
- `test_idx.npy`
|
||||
|
||||
We refer to this as the "raw" data.
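
Before moving on, it can be useful to confirm that the reorganized directory contains all nine files listed above. The following is a minimal sketch using only the standard library; `EXPECTED_RAW_FILES` and `missing_raw_files` are hypothetical names introduced here and are not part of the repository.

```python
import pathlib
from typing import List

# File names taken from the list above.
EXPECTED_RAW_FILES = (
    "node_feat.npy",
    "node_label.npy",
    "node_year.npy",
    "author_affiliated_with_institution_edges.npy",
    "author_writes_paper_edges.npy",
    "paper_cites_paper_edges.npy",
    "train_idx.npy",
    "valid_idx.npy",
    "test_idx.npy",
)


def missing_raw_files(root: str) -> List[str]:
  """Returns the expected raw files absent from ROOT/mag240m_kddcup2021/raw."""
  raw_dir = pathlib.Path(root) / "mag240m_kddcup2021" / "raw"
  return [name for name in EXPECTED_RAW_FILES
          if not (raw_dir / name).is_file()]
```

Calling `missing_raw_files("ROOT")` with your actual `ROOT` should return an empty list once step 2 has completed.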
|
||||
|
||||
**3. Run the preprocessing code**
|
||||
```bash
|
||||
/bin/bash run_preprocessing.sh -r ROOT
|
||||
```
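
One of the preprocessing steps in this entry converts the raw edge arrays into CSR (compressed sparse row) adjacency matrices, together with their transposes. The script itself relies on `scipy.sparse`; the dependency-free sketch below only illustrates the CSR layout on a toy edge list. `edges_to_csr` and `neighbours` are hypothetical helpers and are not part of the repository.

```python
from typing import List, Sequence, Tuple


def edges_to_csr(senders: Sequence[int], receivers: Sequence[int],
                 num_rows: int) -> Tuple[List[int], List[int]]:
  """Returns CSR `(indptr, indices)` with one row per sender node."""
  # Count the number of outgoing edges per row.
  counts = [0] * num_rows
  for s in senders:
    counts[s] += 1
  # indptr[row] is where the neighbour list of `row` starts in `indices`.
  indptr = [0] * (num_rows + 1)
  for row in range(num_rows):
    indptr[row + 1] = indptr[row] + counts[row]
  # Fill each row's slice of `indices` in edge order.
  indices = [0] * len(senders)
  next_slot = list(indptr[:-1])
  for s, r in zip(senders, receivers):
    indices[next_slot[s]] = r
    next_slot[s] += 1
  return indptr, indices


def neighbours(indptr: Sequence[int], indices: Sequence[int],
               row: int) -> List[int]:
  """Looks up the neighbours of `row` as one contiguous slice."""
  return list(indices[indptr[row]:indptr[row + 1]])


toy_senders = [0, 0, 2, 1, 2]
toy_receivers = [1, 2, 0, 2, 1]
indptr, indices = edges_to_csr(toy_senders, toy_receivers, num_rows=3)
# Swapping the two arrays gives the transposed adjacency (incoming edges).
indptr_t, indices_t = edges_to_csr(toy_receivers, toy_senders, num_rows=3)
```

Storing both the matrix and its transpose makes neighbour lookups in either edge direction a single contiguous slice, which is what an online neighbourhood sampler needs.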
|
||||
|
||||
The pre-processing is very time- and memory-consuming, and should only be run
|
||||
to verify the full pipeline. You can download the pre-processed data using the
|
||||
following script, for use in training and evaluating models:
|
||||
|
||||
```bash
|
||||
python3 download_mag.py --task_root=${HOME}/mag --payload="data"
|
||||
```
|
||||
|
||||
|
||||
## Reproducing our final results
|
||||
|
||||
We have provided pre-trained weights of our final submission for convenience.
|
||||
They can be downloaded with:
|
||||
```bash
|
||||
python3 download_mag.py --task_root=${HOME}/mag --payload="models"
|
||||
```
|
||||
Then to reproduce our final results, please run `bash run_pretrain_eval.sh`.
|
||||
|
||||
|
||||
## Retraining our model
|
||||
|
||||
Disclaimer: This script is provided for illustrative purposes. It is not
|
||||
practical for actual training since it only uses a single machine, and likely
|
||||
requires reducing the batch size and/or model size to fit on a single GPU.
|
||||
|
||||
If you still want to train a model, please run `run_training.sh`. To simply
|
||||
validate that the code is running correctly on your hardware setup, consider
|
||||
setting `debug=True` in `config.py`, which trains a smaller model.
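
For a sense of what `debug=True` changes, the sketch below restates the debug-dependent model sizes and training batch limits from this entry's `config.py` as plain Python. It does not import Jaxline or the real config; `model_sizes`, `training_batch_budget` and the `graphs_per_batch` parameter are hypothetical names used only for illustration.

```python
from typing import Any, Dict


def model_sizes(debug: bool = False) -> Dict[str, Any]:
  """Model hyperparameters that depend on the `debug` flag."""
  return dict(
      mlp_hidden_sizes=[32] if debug else [512],
      latent_size=32 if debug else 256,
      num_message_passing_steps=2 if debug else 4,
  )


def training_batch_budget(debug: bool = False,
                          graphs_per_batch: int = 256) -> Dict[str, int]:
  """Padding limits used for dynamic batching during training."""
  if debug:
    return dict(n_node=256, n_edge=512, n_graph=4)
  return dict(
      n_node=340 * graphs_per_batch,
      n_edge=720 * graphs_per_batch,
      n_graph=graphs_per_batch,
  )


full_budget = training_batch_budget()
# The config notes that GPU memory may require reducing the `256`s to `48`.
gpu_budget = training_batch_budget(graphs_per_batch=48)
debug_budget = training_batch_budget(debug=True)
```

Scaling `n_node`, `n_edge` and `n_graph` together keeps the per-graph budget of 340 nodes and 720 edges unchanged while shrinking the padded batch.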
|
||||
|
||||
|
||||
# Citation
|
||||
|
||||
To cite this work (together with our PCQM4M-LSC entry):
|
||||
|
||||
```latex
|
||||
@article{deepmind2021ogb,
|
||||
author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea
|
||||
Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro
|
||||
Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar
|
||||
Veli\v{c}kovi\'{c}},
|
||||
title = {Large-scale graph representation learning with very deep GNNs and
|
||||
self-supervision},
|
||||
year = {2021},
|
||||
}
|
||||
```
|
||||
|
||||
Our technical report can be found [here](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
|
||||
@@ -0,0 +1,138 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dynamic batching utilities."""
|
||||
|
||||
from typing import Generator, Iterable, Iterator, Sequence, Tuple
|
||||
|
||||
import jax.tree_util as tree
|
||||
import jraph
|
||||
import numpy as np
|
||||
|
||||
_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph")
|
||||
|
||||
|
||||
def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple],
|
||||
n_node: int, n_edge: int,
|
||||
n_graph: int) -> Generator[jraph.GraphsTuple, None, None]:
|
||||
"""Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`.
|
||||
|
||||
Elements of the `graphs_tuple_iterator` will be incrementally added to a batch
|
||||
until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This
|
||||
means each element yielded by this generator may have a different number of
graphs in its batch.
|
||||
|
||||
For situations where you have variable sized data, it's useful to be able to
|
||||
have variable sized batches. This is especially the case if you have a loss
|
||||
defined on the variable shaped element (for example, nodes in a graph).
|
||||
|
||||
Args:
|
||||
graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`.
|
||||
n_node: The maximum number of nodes in a batch.
|
||||
n_edge: The maximum number of edges in a batch.
|
||||
n_graph: The maximum number of graphs in a batch.
|
||||
|
||||
Yields:
|
||||
A `jraph.GraphsTuple` batch of graphs.
|
||||
|
||||
Raises:
|
||||
ValueError: if the number of graphs is < 2.
|
||||
RuntimeError: if the `graphs_tuple_iterator` contains elements which are not
|
||||
`jraph.GraphsTuple`s.
|
||||
RuntimeError: if a graph is found which is larger than the batch size.
|
||||
"""
|
||||
if n_graph < 2:
|
||||
raise ValueError("The number of graphs in a batch size must be greater or "
|
||||
f"equal to `2` for padding with graphs, got {n_graph}.")
|
||||
valid_batch_size = (n_node - 1, n_edge, n_graph - 1)
|
||||
accumulated_graphs = []
|
||||
num_accumulated_nodes = 0
|
||||
num_accumulated_edges = 0
|
||||
num_accumulated_graphs = 0
|
||||
for element in graphs_tuple_iterator:
|
||||
element_nodes, element_edges, element_graphs = _get_graph_size(element)
|
||||
if _is_over_batch_size(element, valid_batch_size):
|
||||
graph_size = element_nodes, element_edges, element_graphs
|
||||
graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)}
|
||||
batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)}
|
||||
raise RuntimeError("Found graph bigger than batch size. Valid Batch "
|
||||
f"Size: {batch_size}, Graph Size: {graph_size}")
|
||||
|
||||
if not accumulated_graphs:
|
||||
# If this is the first element of the batch, set it and continue.
|
||||
accumulated_graphs = [element]
|
||||
num_accumulated_nodes = element_nodes
|
||||
num_accumulated_edges = element_edges
|
||||
num_accumulated_graphs = element_graphs
|
||||
continue
|
||||
else:
|
||||
# Otherwise check if there is space for the graph in the batch:
|
||||
if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
|
||||
(num_accumulated_nodes + element_nodes > n_node - 1) or
|
||||
(num_accumulated_edges + element_edges > n_edge)):
|
||||
# If there is not, return the old batch and start a new batch.
|
||||
batched_graph = _batch_np(accumulated_graphs)
|
||||
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
|
||||
accumulated_graphs = [element]
|
||||
num_accumulated_nodes = element_nodes
|
||||
num_accumulated_edges = element_edges
|
||||
num_accumulated_graphs = element_graphs
|
||||
else:
|
||||
# Otherwise, add the graph to the current batch.
|
||||
accumulated_graphs.append(element)
|
||||
num_accumulated_nodes += element_nodes
|
||||
num_accumulated_edges += element_edges
|
||||
num_accumulated_graphs += element_graphs
|
||||
|
||||
# We may still have data in batched graph.
|
||||
if accumulated_graphs:
|
||||
batched_graph = _batch_np(accumulated_graphs)
|
||||
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
|
||||
|
||||
|
||||
def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
  """Concatenates several `jraph.GraphsTuple`s into one, using NumPy arrays."""
  # Calculates offsets for sender and receiver arrays, caused by concatenating
  # the nodes arrays.
  offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]]))

  def _map_concat(nests):
    concat = lambda *args: np.concatenate(args)
    # `tree_map` accepts several trees and applies `concat` to matching leaves.
    return tree.tree_map(concat, *nests)

  return jraph.GraphsTuple(
      n_node=np.concatenate([g.n_node for g in graphs]),
      n_edge=np.concatenate([g.n_edge for g in graphs]),
      nodes=_map_concat([g.nodes for g in graphs]),
      edges=_map_concat([g.edges for g in graphs]),
      globals=_map_concat([g.globals for g in graphs]),
      senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
      receivers=np.concatenate(
          [g.receivers + o for g, o in zip(graphs, offsets)]))
|
||||
|
||||
|
||||
def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]:
|
||||
n_node = np.sum(graph.n_node)
|
||||
n_edge = len(graph.senders)
|
||||
n_graph = len(graph.n_node)
|
||||
return n_node, n_edge, n_graph
|
||||
|
||||
|
||||
def _is_over_batch_size(
|
||||
graph: jraph.GraphsTuple,
|
||||
graph_batch_size: Iterable[int],
|
||||
) -> bool:
|
||||
graph_size = _get_graph_size(graph)
|
||||
return any([x > y for x, y in zip(graph_size, graph_batch_size)])
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Experiment config for MAG240M-LSC entry."""
|
||||
|
||||
from jaxline import base_config
|
||||
from ml_collections import config_dict
|
||||
|
||||
|
||||
def get_config(debug: bool = False) -> config_dict.ConfigDict:
|
||||
"""Get Jaxline experiment config."""
|
||||
config = base_config.get_base_config()
|
||||
config.random_seed = 42
|
||||
# E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
|
||||
config.restore_path = config_dict.placeholder(str)
|
||||
config.experiment_kwargs = config_dict.ConfigDict(
|
||||
dict(
|
||||
config=dict(
|
||||
debug=debug,
|
||||
predictions_dir=config_dict.placeholder(str),
|
||||
# 5 for model selection and early stopping, 50 for final eval.
|
||||
num_eval_iterations_to_ensemble=5,
|
||||
dataset_kwargs=dict(
|
||||
data_root='/data/',
|
||||
online_subsampling_kwargs=dict(
|
||||
max_nb_neighbours_per_type=[
|
||||
[[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]],
|
||||
[[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]],
|
||||
],
|
||||
remove_future_nodes=True,
|
||||
deduplicate_nodes=True,
|
||||
),
|
||||
ratio_unlabeled_data_to_labeled_data=10.0,
|
||||
k_fold_split_id=config_dict.placeholder(int),
|
||||
use_all_labels_when_not_training=False,
|
||||
use_dummy_adjacencies=debug,
|
||||
),
|
||||
optimizer=dict(
|
||||
name='adamw',
|
||||
kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999),
|
||||
learning_rate_schedule=dict(
|
||||
use_schedule=True,
|
||||
base_learning_rate=1e-2,
|
||||
warmup_steps=50000,
|
||||
total_steps=config.get_ref('training_steps'),
|
||||
),
|
||||
),
|
||||
model_config=dict(
|
||||
mlp_hidden_sizes=[32] if debug else [512],
|
||||
latent_size=32 if debug else 256,
|
||||
num_message_passing_steps=2 if debug else 4,
|
||||
activation='relu',
|
||||
dropout_rate=0.3,
|
||||
dropedge_rate=0.25,
|
||||
disable_edge_updates=True,
|
||||
use_sent_edges=True,
|
||||
normalization_type='layer_norm',
|
||||
aggregation_function='sum',
|
||||
),
|
||||
training=dict(
|
||||
loss_config=dict(
|
||||
bgrl_loss_config=dict(
|
||||
stop_gradient_for_supervised_loss=False,
|
||||
bgrl_loss_scale=1.0,
|
||||
symmetrize=True,
|
||||
first_graph_corruption_config=dict(
|
||||
feature_drop_prob=0.4,
|
||||
edge_drop_prob=0.2,
|
||||
),
|
||||
second_graph_corruption_config=dict(
|
||||
feature_drop_prob=0.4,
|
||||
edge_drop_prob=0.2,
|
||||
),
|
||||
),
|
||||
),
|
||||
# GPU memory may require reducing the `256`s below to `48`.
|
||||
dynamic_batch_size_config=dict(
|
||||
n_node=256 if debug else 340 * 256,
|
||||
n_edge=512 if debug else 720 * 256,
|
||||
n_graph=4 if debug else 256,
|
||||
),
|
||||
),
|
||||
eval=dict(
|
||||
split='valid',
|
||||
ema_annealing_schedule=dict(
|
||||
use_schedule=True,
|
||||
base_rate=0.999,
|
||||
total_steps=config.get_ref('training_steps')),
|
||||
dynamic_batch_size_config=dict(
|
||||
n_node=256 if debug else 340 * 128,
|
||||
n_edge=512 if debug else 720 * 128,
|
||||
n_graph=4 if debug else 128,
|
||||
),
|
||||
))))
|
||||
|
||||
## Training loop config.
|
||||
config.training_steps = 500000
|
||||
config.checkpoint_dir = '/tmp/checkpoint/mag/'
|
||||
config.train_checkpoint_all_hosts = False
|
||||
config.log_train_data_interval = 10
|
||||
config.log_tensors_interval = 10
|
||||
config.save_checkpoint_interval = 30
|
||||
config.best_model_eval_metric = 'accuracy'
|
||||
config.best_model_eval_metric_higher_is_better = True
|
||||
|
||||
return config
|
||||
@@ -0,0 +1,136 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Builds CSR matrices which store the MAG graphs."""
|
||||
|
||||
import pathlib
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import scipy.sparse
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import data_utils
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_DATA_FILES_AND_PARAMETERS = {
|
||||
'author_affiliated_with_institution_edges.npy': {
|
||||
'content_names': ('author', 'institution'),
|
||||
'use_boolean': False
|
||||
},
|
||||
'author_writes_paper_edges.npy': {
|
||||
'content_names': ('author', 'paper'),
|
||||
'use_boolean': False
|
||||
},
|
||||
'paper_cites_paper_edges.npy': {
|
||||
'content_names': ('paper', 'paper'),
|
||||
'use_boolean': True
|
||||
},
|
||||
}
|
||||
|
||||
flags.DEFINE_string('data_root', None, 'Data root directory')
|
||||
flags.DEFINE_boolean('skip_existing', True, 'Skips existing CSR files')
|
||||
|
||||
flags.mark_flags_as_required(['data_root'])
|
||||
|
||||
|
||||
def _read_edge_data(path):
|
||||
try:
|
||||
return np.load(path, mmap_mode='r')
|
||||
except FileNotFoundError:
|
||||
# If the file path can't be found by np.load, use the file handle w/o mmap.
|
||||
with path.open('rb') as fid:
|
||||
return np.load(fid)
|
||||
|
||||
|
||||
def _build_coo(edges_data, use_boolean=False):
|
||||
if use_boolean:
|
||||
mat_coo = scipy.sparse.coo_matrix(
|
||||
(np.ones_like(edges_data[1, :],
|
||||
dtype=bool), (edges_data[0, :], edges_data[1, :])))
|
||||
else:
|
||||
mat_coo = scipy.sparse.coo_matrix(
|
||||
(edges_data[1, :], (edges_data[0, :], edges_data[1, :])))
|
||||
return mat_coo
|
||||
|
||||
|
||||
def _get_output_paths(directory, content_names, use_boolean):
|
||||
boolean_str = '_b' if use_boolean else ''
|
||||
transpose_str = '_t' if len(set(content_names)) == 1 else ''
|
||||
output_prefix = '_'.join(content_names)
|
||||
output_prefix_t = '_'.join(content_names[::-1])
|
||||
output_filename = f'{output_prefix}{boolean_str}.npz'
|
||||
output_filename_t = f'{output_prefix_t}{boolean_str}{transpose_str}.npz'
|
||||
output_path = directory / output_filename
|
||||
output_path_t = directory / output_filename_t
|
||||
return output_path, output_path_t
|
||||
|
||||
|
||||
def _write_csr(path, csr):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open('wb') as fid:
|
||||
scipy.sparse.save_npz(fid, csr)
|
||||
|
||||
|
||||
def main(argv):
|
||||
if len(argv) > 1:
|
||||
raise app.UsageError('Too many command-line arguments.')
|
||||
|
||||
raw_data_dir = Path(FLAGS.data_root) / data_utils.RAW_DIR
|
||||
preprocessed_dir = Path(FLAGS.data_root) / data_utils.PREPROCESSED_DIR
|
||||
|
||||
for input_filename, parameters in _DATA_FILES_AND_PARAMETERS.items():
|
||||
input_path = raw_data_dir / input_filename
|
||||
output_path, output_path_t = _get_output_paths(preprocessed_dir,
|
||||
**parameters)
|
||||
if FLAGS.skip_existing and output_path.exists() and output_path_t.exists():
|
||||
# If both files exist, skip. When only one exists, that's handled below.
|
||||
logging.info(
|
||||
'%s and %s exist: skipping. Use flag `--skip_existing=False` '
|
||||
'to force overwrite existing.', output_path, output_path_t)
|
||||
continue
|
||||
logging.info('Reading edge data from: %s', input_path)
|
||||
edge_data = _read_edge_data(input_path)
|
||||
logging.info('Building CSR matrices')
|
||||
mat_coo = _build_coo(edge_data, use_boolean=parameters['use_boolean'])
|
||||
# Convert matrices to CSR and write to disk.
|
||||
if not FLAGS.skip_existing or not output_path.exists():
|
||||
logging.info('Writing CSR matrix to: %s', output_path)
|
||||
mat_csr = mat_coo.tocsr()
|
||||
_write_csr(output_path, mat_csr)
|
||||
del mat_csr # Free up memory asap.
|
||||
else:
|
||||
logging.info(
|
||||
'%s exists: skipping. Use flag `--skip_existing=False` '
|
||||
'to force overwrite existing.', output_path)
|
||||
if not FLAGS.skip_existing or not output_path_t.exists():
|
||||
logging.info('Writing (transposed) CSR matrix to: %s', output_path_t)
|
||||
mat_csr_t = mat_coo.transpose().tocsr()
|
||||
_write_csr(output_path_t, mat_csr_t)
|
||||
del mat_csr_t # Free up memory asap.
|
||||
else:
|
||||
logging.info(
|
||||
'%s exists: skipping. Use flag `--skip_existing=False` '
|
||||
'to force overwrite existing.', output_path_t)
|
||||
del mat_coo # Free up memory asap.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,491 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset utilities."""
|
||||
|
||||
import functools
|
||||
import pathlib
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from absl import logging
|
||||
from graph_nets import graphs as tf_graphs
|
||||
from graph_nets import utils_tf
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
import tensorflow as tf
|
||||
import tqdm
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import sub_sampler
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
NUM_PAPERS = 121751666
|
||||
NUM_AUTHORS = 122383112
|
||||
NUM_INSTITUTIONS = 25721
|
||||
EMBEDDING_SIZE = 768
|
||||
NUM_CLASSES = 153
|
||||
|
||||
NUM_NODES = NUM_PAPERS + NUM_AUTHORS + NUM_INSTITUTIONS
|
||||
NUM_EDGES = 1_728_364_232
|
||||
assert NUM_NODES == 244_160_499
|
||||
|
||||
NUM_K_FOLD_SPLITS = 10
|
||||
|
||||
|
||||
OFFSETS = {
|
||||
"paper": 0,
|
||||
"author": NUM_PAPERS,
|
||||
"institution": NUM_PAPERS + NUM_AUTHORS,
|
||||
}
|
||||
|
||||
|
||||
SIZES = {
|
||||
"paper": NUM_PAPERS,
|
||||
"author": NUM_AUTHORS,
|
||||
"institution": NUM_INSTITUTIONS
|
||||
}
|
||||
|
||||
RAW_DIR = Path("raw")
|
||||
PREPROCESSED_DIR = Path("preprocessed")
|
||||
|
||||
RAW_NODE_FEATURES_FILENAME = RAW_DIR / "node_feat.npy"
|
||||
RAW_NODE_LABELS_FILENAME = RAW_DIR / "node_label.npy"
|
||||
RAW_NODE_YEAR_FILENAME = RAW_DIR / "node_year.npy"
|
||||
|
||||
TRAIN_INDEX_FILENAME = RAW_DIR / "train_idx.npy"
|
||||
VALID_INDEX_FILENAME = RAW_DIR / "valid_idx.npy"
|
||||
TEST_INDEX_FILENAME = RAW_DIR / "test_idx.npy"
|
||||
|
||||
EDGES_PAPER_PAPER_B = PREPROCESSED_DIR / "paper_paper_b.npz"
|
||||
EDGES_PAPER_PAPER_B_T = PREPROCESSED_DIR / "paper_paper_b_t.npz"
|
||||
EDGES_AUTHOR_INSTITUTION = PREPROCESSED_DIR / "author_institution.npz"
|
||||
EDGES_INSTITUTION_AUTHOR = PREPROCESSED_DIR / "institution_author.npz"
|
||||
EDGES_AUTHOR_PAPER = PREPROCESSED_DIR / "author_paper.npz"
|
||||
EDGES_PAPER_AUTHOR = PREPROCESSED_DIR / "paper_author.npz"
|
||||
|
||||
PCA_PAPER_FEATURES_FILENAME = PREPROCESSED_DIR / "paper_feat_pca_129.npy"
|
||||
PCA_AUTHOR_FEATURES_FILENAME = (
|
||||
PREPROCESSED_DIR / "author_feat_from_paper_feat_pca_129.npy")
|
||||
PCA_INSTITUTION_FEATURES_FILENAME = (
|
||||
PREPROCESSED_DIR / "institution_feat_from_paper_feat_pca_129.npy")
|
||||
PCA_MERGED_FEATURES_FILENAME = (
|
||||
PREPROCESSED_DIR / "merged_feat_from_paper_feat_pca_129.npy")
|
||||
|
||||
NEIGHBOR_INDICES_FILENAME = PREPROCESSED_DIR / "neighbor_indices.npy"
|
||||
NEIGHBOR_DISTANCES_FILENAME = PREPROCESSED_DIR / "neighbor_distances.npy"
|
||||
|
||||
FUSED_NODE_LABELS_FILENAME = PREPROCESSED_DIR / "fused_node_labels.npy"
|
||||
FUSED_PAPER_EDGES_FILENAME = PREPROCESSED_DIR / "fused_paper_edges.npz"
|
||||
FUSED_PAPER_EDGES_T_FILENAME = PREPROCESSED_DIR / "fused_paper_edges_t.npz"
|
||||
|
||||
K_FOLD_SPLITS_DIR = Path("k_fold_splits")
|
||||
|
||||
|
||||
def get_raw_directory(data_root):
|
||||
return Path(data_root) / "raw"
|
||||
|
||||
|
||||
def get_preprocessed_directory(data_root):
|
||||
return Path(data_root) / "preprocessed"
|
||||
|
||||
|
||||
def _log_path_decorator(fn):
|
||||
def _decorated_fn(path, **kwargs):
|
||||
logging.info("Loading %s", path)
|
||||
output = fn(path, **kwargs)
|
||||
logging.info("Finished loading %s", path)
|
||||
return output
|
||||
return _decorated_fn
|
||||
|
||||
|
||||
@_log_path_decorator
|
||||
def load_csr(path, debug=False):
|
||||
if debug:
|
||||
# Dummy matrix for debugging.
|
||||
return sp.csr_matrix(np.zeros([10, 10]))
|
||||
return sp.load_npz(str(path))
|
||||
|
||||
|
||||
@_log_path_decorator
|
||||
def load_npy(path):
|
||||
return np.load(str(path))
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_arrays(data_root="/data/",
|
||||
use_fused_node_labels=True,
|
||||
use_fused_node_adjacencies=True,
|
||||
return_pca_embeddings=True,
|
||||
k_fold_split_id=None,
|
||||
return_adjacencies=True,
|
||||
use_dummy_adjacencies=False):
|
||||
"""Returns all arrays needed for training."""
|
||||
logging.info("Starting to get files")
|
||||
|
||||
data_root = Path(data_root)
|
||||
|
||||
array_dict = {}
|
||||
array_dict["paper_year"] = load_npy(data_root / RAW_NODE_YEAR_FILENAME)
|
||||
|
||||
if k_fold_split_id is None:
|
||||
train_indices = load_npy(data_root / TRAIN_INDEX_FILENAME)
|
||||
valid_indices = load_npy(data_root / VALID_INDEX_FILENAME)
|
||||
else:
|
||||
train_indices, valid_indices = get_train_and_valid_idx_for_split(
|
||||
k_fold_split_id, num_splits=NUM_K_FOLD_SPLITS,
|
||||
root_path=data_root / K_FOLD_SPLITS_DIR)
|
||||
|
||||
array_dict["train_indices"] = train_indices
|
||||
array_dict["valid_indices"] = valid_indices
|
||||
array_dict["test_indices"] = load_npy(data_root / TEST_INDEX_FILENAME)
|
||||
|
||||
if use_fused_node_labels:
|
||||
array_dict["paper_label"] = load_npy(data_root / FUSED_NODE_LABELS_FILENAME)
|
||||
else:
|
||||
array_dict["paper_label"] = load_npy(data_root / RAW_NODE_LABELS_FILENAME)
|
||||
|
||||
if return_adjacencies:
|
||||
logging.info("Starting to get adjacencies.")
|
||||
if use_fused_node_adjacencies:
|
||||
paper_paper_index = load_csr(
|
||||
data_root / FUSED_PAPER_EDGES_FILENAME, debug=use_dummy_adjacencies)
|
||||
paper_paper_index_t = load_csr(
|
||||
data_root / FUSED_PAPER_EDGES_T_FILENAME, debug=use_dummy_adjacencies)
|
||||
else:
|
||||
paper_paper_index = load_csr(
|
||||
data_root / EDGES_PAPER_PAPER_B, debug=use_dummy_adjacencies)
|
||||
paper_paper_index_t = load_csr(
|
||||
data_root / EDGES_PAPER_PAPER_B_T, debug=use_dummy_adjacencies)
|
||||
array_dict.update(
|
||||
dict(
|
||||
author_institution_index=load_csr(
|
||||
data_root / EDGES_AUTHOR_INSTITUTION,
|
||||
debug=use_dummy_adjacencies),
|
||||
institution_author_index=load_csr(
|
||||
data_root / EDGES_INSTITUTION_AUTHOR,
|
||||
debug=use_dummy_adjacencies),
|
||||
author_paper_index=load_csr(
|
||||
data_root / EDGES_AUTHOR_PAPER, debug=use_dummy_adjacencies),
|
||||
paper_author_index=load_csr(
|
||||
data_root / EDGES_PAPER_AUTHOR, debug=use_dummy_adjacencies),
|
||||
paper_paper_index=paper_paper_index,
|
||||
paper_paper_index_t=paper_paper_index_t,
|
||||
))
|
||||
|
||||
if return_pca_embeddings:
|
||||
array_dict["bert_pca_129"] = np.load(
|
||||
data_root / PCA_MERGED_FEATURES_FILENAME, mmap_mode="r")
|
||||
assert array_dict["bert_pca_129"].shape == (NUM_NODES, 129)
|
||||
|
||||
logging.info("Finished getting files")
|
||||
|
||||
# pytype: disable=attribute-error
|
||||
assert array_dict["paper_year"].shape[0] == NUM_PAPERS
|
||||
assert array_dict["paper_label"].shape[0] == NUM_PAPERS
|
||||
|
||||
if return_adjacencies and not use_dummy_adjacencies:
|
||||
array_dict = _fix_adjacency_shapes(array_dict)
|
||||
|
||||
assert array_dict["paper_author_index"].shape == (NUM_PAPERS, NUM_AUTHORS)
|
||||
assert array_dict["author_paper_index"].shape == (NUM_AUTHORS, NUM_PAPERS)
|
||||
assert array_dict["paper_paper_index"].shape == (NUM_PAPERS, NUM_PAPERS)
|
||||
assert array_dict["paper_paper_index_t"].shape == (NUM_PAPERS, NUM_PAPERS)
|
||||
assert array_dict["institution_author_index"].shape == (
|
||||
NUM_INSTITUTIONS, NUM_AUTHORS)
|
||||
assert array_dict["author_institution_index"].shape == (
|
||||
NUM_AUTHORS, NUM_INSTITUTIONS)
|
||||
|
||||
# pytype: enable=attribute-error
|
||||
|
||||
return array_dict
|
||||
|
||||
|
||||
def add_nodes_year(graph, paper_year):
|
||||
nodes = graph.nodes.copy()
|
||||
indices = nodes["index"]
|
||||
year = paper_year[np.minimum(indices, paper_year.shape[0] - 1)].copy()
|
||||
year[nodes["type"] != 0] = 1900
|
||||
nodes["year"] = year
|
||||
return graph._replace(nodes=nodes)
|
||||
|
||||
|
||||
def add_nodes_label(graph, paper_label):
|
||||
nodes = graph.nodes.copy()
|
||||
indices = nodes["index"]
|
||||
label = paper_label[np.minimum(indices, paper_label.shape[0] - 1)]
|
||||
label[nodes["type"] != 0] = 0
|
||||
nodes["label"] = label
|
||||
return graph._replace(nodes=nodes)
|
||||
|
||||
|
||||
def add_nodes_embedding_from_array(graph, array):
"""Adds embeddings, gathered from `array`, for the node indices in `graph`."""
|
||||
nodes = graph.nodes.copy()
|
||||
indices = nodes["index"]
|
||||
embedding_indices = indices.copy()
|
||||
embedding_indices[nodes["type"] == 1] += NUM_PAPERS
|
||||
embedding_indices[nodes["type"] == 2] += NUM_PAPERS + NUM_AUTHORS
|
||||
|
||||
# Gather the embeddings for the indices.
|
||||
nodes["features"] = array[embedding_indices]
|
||||
return graph._replace(nodes=nodes)
|
||||
|
||||
|
||||
def get_graph_subsampling_dataset(
|
||||
prefix, arrays, shuffle_indices, ratio_unlabeled_data_to_labeled_data,
|
||||
max_nodes, max_edges,
|
||||
**subsampler_kwargs):
|
||||
"""Returns tf_dataset for online sampling."""
|
||||
|
||||
def generator():
|
||||
labeled_indices = arrays[f"{prefix}_indices"]
|
||||
if ratio_unlabeled_data_to_labeled_data > 0:
|
||||
num_unlabeled_data_to_add = int(ratio_unlabeled_data_to_labeled_data *
|
||||
labeled_indices.shape[0])
|
||||
unlabeled_indices = np.random.choice(
|
||||
NUM_PAPERS, size=num_unlabeled_data_to_add, replace=False)
|
||||
root_node_indices = np.concatenate([labeled_indices, unlabeled_indices])
|
||||
else:
|
||||
root_node_indices = labeled_indices
|
||||
if shuffle_indices:
|
||||
root_node_indices = root_node_indices.copy()
|
||||
np.random.shuffle(root_node_indices)
|
||||
|
||||
for index in root_node_indices:
|
||||
graph = sub_sampler.subsample_graph(
|
||||
index,
|
||||
arrays["author_institution_index"],
|
||||
arrays["institution_author_index"],
|
||||
arrays["author_paper_index"],
|
||||
arrays["paper_author_index"],
|
||||
arrays["paper_paper_index"],
|
||||
arrays["paper_paper_index_t"],
|
||||
paper_years=arrays["paper_year"],
|
||||
max_nodes=max_nodes,
|
||||
max_edges=max_edges,
|
||||
**subsampler_kwargs)
|
||||
|
||||
graph = add_nodes_label(graph, arrays["paper_label"])
|
||||
graph = add_nodes_year(graph, arrays["paper_year"])
|
||||
graph = tf_graphs.GraphsTuple(*graph)
|
||||
yield graph
|
||||
|
||||
sample_graph = next(generator())
|
||||
|
||||
return tf.data.Dataset.from_generator(
|
||||
generator,
|
||||
output_signature=utils_tf.specs_from_graphs_tuple(sample_graph))
|
||||
|
||||
|
||||
def paper_features_to_author_features(
|
||||
author_paper_index, paper_features):
|
||||
"""Averages paper features to authors."""
|
||||
assert paper_features.shape[0] == NUM_PAPERS
|
||||
assert author_paper_index.shape[0] == NUM_AUTHORS
|
||||
author_features = np.zeros(
|
||||
[NUM_AUTHORS, paper_features.shape[1]], dtype=paper_features.dtype)
|
||||
for author_i in range(NUM_AUTHORS):
|
||||
paper_indices = author_paper_index[author_i].indices
|
||||
author_features[author_i] = paper_features[paper_indices].mean(
|
||||
axis=0, dtype=np.float32)
|
||||
if author_i % 10000 == 0:
|
||||
logging.info("%d/%d", author_i, NUM_AUTHORS)
|
||||
return author_features
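The loop above averages one row at a time. As an alternative sketch (not the method used in this file), the same neighbour mean can be computed in one pass from the CSR `indptr` and `indices` arrays. The `mean_over_neighbors` helper and the toy data are hypothetical. Unlike the loop, rows without neighbours come out as zeros here rather than NaN.

```python
import numpy as np


def mean_over_neighbors(indptr, indices, features):
  """Averages `features` rows over the neighbours listed in a CSR matrix."""
  num_rows = indptr.shape[0] - 1
  counts = np.diff(indptr)  # Number of neighbours per row.
  # Row id for every stored entry, e.g. counts [2, 1, 0] -> [0, 0, 1].
  row_ids = np.repeat(np.arange(num_rows), counts)
  sums = np.zeros((num_rows, features.shape[1]), dtype=np.float32)
  # Unbuffered add, so repeated row ids accumulate correctly.
  np.add.at(sums, row_ids, features[indices])
  # Guard against division by zero for rows without neighbours.
  return sums / np.maximum(counts, 1)[:, None]


# Toy graph: author 0 wrote papers {0, 1}, author 1 wrote {1}, author 2 none.
indptr = np.array([0, 2, 3, 3])
indices = np.array([0, 1, 1])
paper_features = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float16)
author_features = mean_over_neighbors(indptr, indices, paper_features)
```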
|
||||
|
||||
|
||||
def author_features_to_institution_features(
|
||||
institution_author_index, author_features):
|
||||
"""Averages author features to institutions."""
|
||||
assert author_features.shape[0] == NUM_AUTHORS
|
||||
assert institution_author_index.shape[0] == NUM_INSTITUTIONS
|
||||
institution_features = np.zeros(
|
||||
[NUM_INSTITUTIONS, author_features.shape[1]], dtype=author_features.dtype)
|
||||
for institution_i in range(NUM_INSTITUTIONS):
|
||||
author_indices = institution_author_index[institution_i].indices
|
||||
institution_features[institution_i] = author_features[
|
||||
author_indices].mean(axis=0, dtype=np.float32)
|
||||
if institution_i % 10000 == 0:
|
||||
logging.info("%d/%d", institution_i, NUM_INSTITUTIONS)
|
||||
return institution_features
|
||||
|
||||
|
||||
def generate_fused_paper_adjacency_matrix(neighbor_indices, neighbor_distances,
|
||||
paper_paper_csr):
|
||||
"""Generates fused adjacency matrix for identical nodes."""
|
||||
# First construct set of identical node indices.
|
||||
# NOTE: Since we take only top K=26 identical pairs for each node, this is not
|
||||
# actually exhaustive. Also, if A and B are equal, and B and C are equal,
|
||||
# this method would not necessarily detect A and C being equal.
|
||||
# However, this should capture almost all cases.
|
||||
logging.info("Generating fused paper adjacency matrix")
|
||||
eps = 0.0
|
||||
mask = ((neighbor_indices != np.mgrid[:neighbor_indices.shape[0], :1]) &
|
||||
(neighbor_distances <= eps))
|
||||
identical_pairs = list(map(tuple, np.nonzero(mask)))
|
||||
del mask
|
||||
|
||||
# Have a csc version for fast column access.
|
||||
paper_paper_csc = paper_paper_csr.tocsc()
|
||||
|
||||
# Construct new matrix as coo, starting off with original rows/cols.
|
||||
paper_paper_coo = paper_paper_csr.tocoo()
|
||||
new_rows = [paper_paper_coo.row]
|
||||
new_cols = [paper_paper_coo.col]
|
||||
|
||||
for pair in tqdm.tqdm(identical_pairs):
|
||||
# STEP ONE: First merge papers being cited by the pair.
|
||||
# Add edges from second paper, to all papers cited by first paper.
|
||||
cited_by_first = paper_paper_csr.getrow(pair[0]).nonzero()[1]
|
||||
if cited_by_first.shape[0] > 0:
|
||||
new_rows.append(pair[1] * np.ones_like(cited_by_first))
|
||||
new_cols.append(cited_by_first)
|
||||
|
||||
# Add edges from first paper, to all papers cited by second paper.
|
||||
cited_by_second = paper_paper_csr.getrow(pair[1]).nonzero()[1]
|
||||
if cited_by_second.shape[0] > 0:
|
||||
new_rows.append(pair[0] * np.ones_like(cited_by_second))
|
||||
new_cols.append(cited_by_second)
|
||||
|
||||
# STEP TWO: Then merge papers that cite the pair.
|
||||
# Add edges to second paper, from all papers citing the first paper.
|
||||
citing_first = paper_paper_csc.getcol(pair[0]).nonzero()[0]
|
||||
if citing_first.shape[0] > 0:
|
||||
new_rows.append(citing_first)
|
||||
new_cols.append(pair[1] * np.ones_like(citing_first))
|
||||
|
||||
# Add edges to first paper, from all papers citing the second paper.
|
||||
citing_second = paper_paper_csc.getcol(pair[1]).nonzero()[0]
|
||||
if citing_second.shape[0] > 0:
|
||||
new_rows.append(citing_second)
|
||||
new_cols.append(pair[0] * np.ones_like(citing_second))
|
||||
|
||||
logging.info("Done with adjacency loop")
|
||||
paper_paper_coo_shape = paper_paper_coo.shape
|
||||
del paper_paper_csr
|
||||
del paper_paper_csc
|
||||
del paper_paper_coo
|
||||
# All done; now concatenate everything together and form new matrix.
|
||||
new_rows = np.concatenate(new_rows)
|
||||
new_cols = np.concatenate(new_cols)
|
||||
return sp.coo_matrix(
|
||||
(np.ones_like(new_rows, dtype=bool), (new_rows, new_cols)),
|
||||
shape=paper_paper_coo_shape).tocsr()
|
||||
|
||||
|
||||
def generate_k_fold_splits(
|
||||
train_idx, valid_idx, output_path, num_splits=NUM_K_FOLD_SPLITS):
|
||||
"""Generates splits adding fractions of the validation split to training."""
|
||||
output_path = Path(output_path)
|
||||
np.random.seed(42)
|
||||
valid_idx = np.random.permutation(valid_idx)
|
||||
# Split into `num_splits` (almost) identically sized arrays.
|
||||
valid_idx_parts = np.array_split(valid_idx, num_splits)
|
||||
|
||||
for i in range(num_splits):
|
||||
# Add all but the i'th subpart to training set.
|
||||
new_train_idx = np.concatenate(
|
||||
[train_idx, *valid_idx_parts[:i], *valid_idx_parts[i+1:]])
|
||||
# i'th subpart is validation set.
|
||||
new_valid_idx = valid_idx_parts[i]
|
||||
train_path = output_path / f"train_idx_{i}_{num_splits}.npy"
|
||||
valid_path = output_path / f"valid_idx_{i}_{num_splits}.npy"
|
||||
np.save(train_path, new_train_idx)
|
||||
np.save(valid_path, new_valid_idx)
|
||||
logging.info("Saved: %s", train_path)
|
||||
logging.info("Saved: %s", valid_path)
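The split logic above can be checked on a toy example. The sketch below keeps the splits in memory instead of saving `.npy` files; the `k_fold_splits` helper, its `seed` argument and the toy index ranges are assumptions made for illustration.

```python
import numpy as np


def k_fold_splits(train_idx, valid_idx, num_splits, seed=42):
  """Returns (train, valid) index pairs, one per fold."""
  rng = np.random.RandomState(seed)
  parts = np.array_split(rng.permutation(valid_idx), num_splits)
  splits = []
  for i in range(num_splits):
    # All validation parts except the i'th one are added to training.
    new_train = np.concatenate([train_idx, *parts[:i], *parts[i + 1:]])
    splits.append((new_train, parts[i]))
  return splits


# 10 training nodes and 10 validation nodes, split into 5 folds.
splits = k_fold_splits(np.arange(10), np.arange(10, 20), num_splits=5)
```

Each fold therefore trains on the original training set plus 8 of the 10 validation nodes, and validates on the remaining 2.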
|
||||
|
||||
|
||||
def get_train_and_valid_idx_for_split(
|
||||
split_id: int,
|
||||
num_splits: int,
|
||||
root_path: str,
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""Returns train and valid indices for given split."""
|
||||
new_train_idx = load_npy(f"{root_path}/train_idx_{split_id}_{num_splits}.npy")
|
||||
new_valid_idx = load_npy(f"{root_path}/valid_idx_{split_id}_{num_splits}.npy")
|
||||
return new_train_idx, new_valid_idx
|
||||
|
||||
|
||||
def generate_fused_node_labels(neighbor_indices, neighbor_distances,
                               node_labels, train_indices, valid_indices,
                               test_indices):
  """Generates fused node labels for identical nodes."""
  logging.info("Generating fused node labels")
  valid_indices = set(valid_indices.tolist())
  test_indices = set(test_indices.tolist())
  valid_or_test_indices = valid_indices | test_indices

  train_indices = train_indices[train_indices < neighbor_indices.shape[0]]
  # Go through list of all pairs where one node is in training set, and
  for i in tqdm.tqdm(train_indices):
    for j in range(neighbor_indices.shape[1]):
      other_index = neighbor_indices[i][j]
      # if the other is not a validation or test node,
      if other_index in valid_or_test_indices:
        continue
      # and they are identical,
      if neighbor_distances[i][j] == 0:
        # assign the label of the training node to the other node
        node_labels[other_index] = node_labels[i]

  return node_labels
|
||||
|
||||
|
||||
def _pad_to_shape(
    sparse_csr_matrix: sp.csr_matrix,
    output_shape: Tuple[int, int]) -> sp.csr_matrix:
  """Pads a csr sparse matrix to the given shape."""

  # The output shape must be at least as large as the input in every
  # dimension. Compare per dimension: a plain tuple comparison would be
  # lexicographic.
  assert all(
      size <= target
      for size, target in zip(sparse_csr_matrix.shape, output_shape))

  # Maybe it already has the right shape.
  if sparse_csr_matrix.shape == output_shape:
    return sparse_csr_matrix

  # Append as many indptr elements as we need to match the leading size.
  # This is achieved by just padding with copies of the last indptr element.
  required_padding = output_shape[0] - sparse_csr_matrix.shape[0]
  updated_indptr = np.concatenate(
      [sparse_csr_matrix.indptr] +
      [sparse_csr_matrix.indptr[-1:]] * required_padding,
      axis=0)

  # The change in trailing size does not have structural implications, it just
  # determines the highest possible value for the indices, so it is sufficient
  # to just pass the new output shape, with the correct trailing size.
  return sp.csr_matrix(
      (sparse_csr_matrix.data,
       sparse_csr_matrix.indices,
       updated_indptr),
      shape=output_shape)
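To see why repeating the last `indptr` entry is enough to add empty rows, the sketch below works on raw CSR arrays with NumPy only. The helpers `pad_csr_rows` and `csr_to_dense` and the toy matrix are hypothetical and exist only to illustrate the idea used in `_pad_to_shape`.

```python
import numpy as np


def pad_csr_rows(indptr, num_rows):
  """Extends a CSR row-pointer array so that it describes `num_rows` rows."""
  current_rows = indptr.shape[0] - 1
  assert num_rows >= current_rows
  # Every extra row starts and ends at the last offset, so it holds no entries.
  padding = np.full(num_rows - current_rows, indptr[-1], dtype=indptr.dtype)
  return np.concatenate([indptr, padding])


def csr_to_dense(data, indices, indptr, shape):
  """Materialises CSR arrays as a dense matrix (for inspection only)."""
  dense = np.zeros(shape, dtype=data.dtype)
  rows = np.repeat(np.arange(indptr.shape[0] - 1), np.diff(indptr))
  dense[rows, indices] = data
  return dense


# A 2 x 3 matrix: row 0 has entries in columns {0, 2}, row 1 in column {1}.
data = np.array([1, 1, 1])
indices = np.array([0, 2, 1])
indptr = np.array([0, 2, 3])

padded_indptr = pad_csr_rows(indptr, num_rows=4)
# Growing the column count needs no array changes, only a larger shape.
dense = csr_to_dense(data, indices, padded_indptr, shape=(4, 5))
```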
|
||||
|
||||
|
||||
def _fix_adjacency_shapes(
    arrays: Dict[str, sp.csr_matrix],
    ) -> Dict[str, sp.csr_matrix]:
  """Pads each adjacency matrix to (num senders, num receivers)."""
  arrays = arrays.copy()
  for key in ["author_institution_index",
              "author_paper_index",
              "paper_paper_index",
              "institution_author_index",
              "paper_author_index",
              "paper_paper_index_t"]:
    type_sender = key.split("_")[0]
    type_receiver = key.split("_")[1]
    arrays[key] = _pad_to_shape(
        arrays[key], output_shape=(SIZES[type_sender], SIZES[type_receiver]))
  return arrays
|
||||
@@ -0,0 +1,297 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MAG240M-LSC datasets."""
|
||||
|
||||
import threading
|
||||
from typing import NamedTuple, Optional
|
||||
|
||||
|
||||
import jax
|
||||
import jraph
|
||||
from ml_collections import config_dict
|
||||
import numpy as np
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
# pytype: disable=import-error
|
||||
import batching_utils
|
||||
import data_utils
|
||||
|
||||
|
||||
# We only want to load these arrays once for all threads.
|
||||
# `get_arrays` uses an LRU cache which is not thread safe.
|
||||
LOADING_RAW_ARRAYS_LOCK = threading.Lock()
|
||||
|
||||
NUM_CLASSES = data_utils.NUM_CLASSES
|
||||
|
||||
|
||||
_MAX_DEPTH_IN_SUBGRAPH = 3
|
||||
|
||||
|
||||
class Batch(NamedTuple):
|
||||
"""NamedTuple to represent batches of data."""
|
||||
graph: jraph.GraphsTuple
|
||||
node_labels: np.ndarray
|
||||
label_mask: np.ndarray
|
||||
central_node_mask: np.ndarray
|
||||
node_indices: np.ndarray
|
||||
absolute_node_indices: np.ndarray
|
||||
|
||||
|
||||
def build_dataset_iterator(
|
||||
data_root: str,
|
||||
split: str,
|
||||
dynamic_batch_size_config: config_dict.ConfigDict,
|
||||
online_subsampling_kwargs: dict, # pylint: disable=g-bare-generic
|
||||
debug: bool = False,
|
||||
is_training: bool = True,
|
||||
k_fold_split_id: Optional[int] = None,
|
||||
ratio_unlabeled_data_to_labeled_data: float = 0.0,
|
||||
use_all_labels_when_not_training: bool = False,
|
||||
use_dummy_adjacencies: bool = False,
|
||||
):
|
||||
"""Returns an iterator over Batches from the dataset."""
|
||||
|
||||
if split == 'test':
|
||||
use_all_labels_when_not_training = True
|
||||
|
||||
if not is_training:
|
||||
ratio_unlabeled_data_to_labeled_data = 0.0
|
||||
|
||||
# Load the master data arrays.
|
||||
with LOADING_RAW_ARRAYS_LOCK:
|
||||
array_dict = data_utils.get_arrays(
|
||||
data_root, k_fold_split_id=k_fold_split_id,
|
||||
use_dummy_adjacencies=use_dummy_adjacencies)
|
||||
|
||||
node_labels = array_dict['paper_label'].reshape(-1)
|
||||
train_indices = array_dict['train_indices'].astype(np.int32)
|
||||
is_train_index = np.zeros(node_labels.shape[0], dtype=np.int32)
|
||||
is_train_index[train_indices] = 1
|
||||
valid_indices = array_dict['valid_indices'].astype(np.int32)
|
||||
is_valid_index = np.zeros(node_labels.shape[0], dtype=np.int32)
|
||||
is_valid_index[valid_indices] = 1
|
||||
is_train_or_valid_index = is_train_index + is_valid_index
|
||||
|
||||
def sstable_to_intermediate_graph(graph):
|
||||
indices = tf.cast(graph.nodes['index'], tf.int32)
|
||||
first_index = indices[..., 0]
|
||||
|
||||
# Add an additional absolute index by adding offsets to the author and
|
||||
# institution indices.
|
||||
absolute_index = graph.nodes['index']
|
||||
is_author = graph.nodes['type'] == 1
|
||||
absolute_index = tf.where(
|
||||
is_author, absolute_index + data_utils.NUM_PAPERS, absolute_index)
|
||||
is_institution = graph.nodes['type'] == 2
|
||||
absolute_index = tf.where(
|
||||
is_institution,
|
||||
absolute_index + data_utils.NUM_PAPERS + data_utils.NUM_AUTHORS,
|
||||
absolute_index)
|
||||
|
||||
is_same_as_central_node = tf.math.equal(indices, first_index)
|
||||
input_nodes = graph.nodes
|
||||
graph = graph._replace(
|
||||
nodes={
|
||||
'one_hot_type':
|
||||
tf.one_hot(tf.cast(input_nodes['type'], tf.int32), 3),
|
||||
'one_hot_depth':
|
||||
tf.one_hot(
|
||||
tf.cast(input_nodes['depth'], tf.int32),
|
||||
_MAX_DEPTH_IN_SUBGRAPH),
|
||||
'year':
|
||||
tf.expand_dims(input_nodes['year'], axis=-1),
|
||||
'label':
|
||||
tf.one_hot(
|
||||
tf.cast(input_nodes['label'], tf.int32),
|
||||
NUM_CLASSES),
|
||||
'is_same_as_central_node':
|
||||
is_same_as_central_node,
|
||||
# Only first node in graph has a valid label.
|
||||
'is_central_node':
|
||||
tf.one_hot(0,
|
||||
tf.shape(input_nodes['label'])[0]),
|
||||
'index':
|
||||
input_nodes['index'],
|
||||
'absolute_index': absolute_index,
|
||||
},
|
||||
globals=tf.expand_dims(graph.globals, axis=-1),
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
ds = data_utils.get_graph_subsampling_dataset(
|
||||
split,
|
||||
array_dict,
|
||||
shuffle_indices=is_training,
|
||||
ratio_unlabeled_data_to_labeled_data=ratio_unlabeled_data_to_labeled_data,
|
||||
max_nodes=dynamic_batch_size_config.n_node - 1, # Keep space for pads.
|
||||
max_edges=dynamic_batch_size_config.n_edge,
|
||||
**online_subsampling_kwargs)
|
||||
if debug:
|
||||
ds = ds.take(50)
|
||||
ds = ds.map(
|
||||
sstable_to_intermediate_graph,
|
||||
num_parallel_calls=tf.data.experimental.AUTOTUNE)
|
||||
|
||||
if is_training:
|
||||
ds = ds.shard(jax.process_count(), jax.process_index())
|
||||
ds = ds.shuffle(buffer_size=1 if debug else 128)
|
||||
ds = ds.repeat()
|
||||
ds = ds.prefetch(1 if debug else tf.data.experimental.AUTOTUNE)
|
||||
np_ds = iter(tfds.as_numpy(ds))
|
||||
batched_np_ds = batching_utils.dynamically_batch(
|
||||
np_ds,
|
||||
**dynamic_batch_size_config,
|
||||
)
|
||||
|
||||
def intermediate_graph_to_batch(graph):
|
||||
central_node_mask = graph.nodes['is_central_node']
|
||||
label = graph.nodes['label']
|
||||
node_indices = graph.nodes['index']
|
||||
absolute_indices = graph.nodes['absolute_index']
|
||||
|
||||
### Construct label as a feature for non-central nodes.
|
||||
# First do a lookup with node indices, with a np.minimum to ensure we do not
|
||||
# index out of bounds due to num_authors being larger than num_papers.
|
||||
is_same_as_central_node = graph.nodes['is_same_as_central_node']
|
||||
capped_indices = np.minimum(node_indices, node_labels.shape[0] - 1)
|
||||
label_as_feature = node_labels[capped_indices]
|
||||
# Nodes which are not in train set should get `NUM_CLASSES` label.
|
||||
# Nodes in test set or non-arXiv nodes have -1 or nan labels.
|
||||
|
||||
# Mask out invalid labels and non-papers.
|
||||
use_label_as_feature = np.logical_and(label_as_feature >= 0,
|
||||
graph.nodes['one_hot_type'][..., 0])
|
||||
if split == 'train' or not use_all_labels_when_not_training:
|
||||
# Mask out validation papers and non-arxiv papers that
|
||||
# got labels from fusing with arxiv papers.
|
||||
use_label_as_feature = np.logical_and(is_train_index[capped_indices],
|
||||
use_label_as_feature)
|
||||
label_as_feature = np.where(use_label_as_feature, label_as_feature,
|
||||
NUM_CLASSES)
|
||||
# Mask out central node label in case it appears again.
|
||||
label_as_feature = np.where(is_same_as_central_node, NUM_CLASSES,
|
||||
label_as_feature)
|
||||
# Nodes which are not papers get `NUM_CLASSES+1` label.
|
||||
label_as_feature = np.where(graph.nodes['one_hot_type'][..., 0],
|
||||
label_as_feature, NUM_CLASSES+1)
|
||||
|
||||
nodes = {
|
||||
'label_as_feature': label_as_feature,
|
||||
'year': graph.nodes['year'],
|
||||
'bitstring_year': _get_bitstring_year_representation(
|
||||
graph.nodes['year']),
|
||||
'one_hot_type': graph.nodes['one_hot_type'],
|
||||
'one_hot_depth': graph.nodes['one_hot_depth'],
|
||||
}
|
||||
|
||||
graph = graph._replace(
|
||||
nodes=nodes,
|
||||
globals={},
|
||||
)
|
||||
is_train_or_valid_node = np.logical_and(
|
||||
is_train_or_valid_index[capped_indices],
|
||||
graph.nodes['one_hot_type'][..., 0])
|
||||
if is_training:
|
||||
label_mask = np.logical_and(central_node_mask, is_train_or_valid_node)
|
||||
else:
|
||||
# `label_mask` is used to index into valid central nodes by prediction
|
||||
# calculator. Since that computation is only done when not training, and
|
||||
# at that time we are guaranteed all central nodes have valid labels,
|
||||
# we just set label_mask = central_node_mask when not training.
|
||||
label_mask = central_node_mask
|
||||
batch = Batch(
|
||||
graph=graph,
|
||||
node_labels=label,
|
||||
central_node_mask=central_node_mask,
|
||||
label_mask=label_mask,
|
||||
node_indices=node_indices,
|
||||
absolute_node_indices=absolute_indices)
|
||||
|
||||
# Transform integers into one-hots.
|
||||
batch = _add_one_hot_features_to_batch(batch)
|
||||
|
||||
# Gather PCA features.
|
||||
return _add_embeddings_to_batch(batch, array_dict['bert_pca_129'])
|
||||
|
||||
batch_list = []
|
||||
for batch in batched_np_ds:
|
||||
with jax.profiler.StepTraceAnnotation('batch_postprocessing'):
|
||||
batch = intermediate_graph_to_batch(batch)
|
||||
if is_training:
|
||||
batch_list.append(batch)
|
||||
if len(batch_list) == jax.local_device_count():
|
||||
yield jax.device_put_sharded(batch_list, jax.local_devices())
|
||||
batch_list = []
|
||||
else:
|
||||
yield batch
|
||||
|
||||
|
||||
def _get_bitstring_year_representation(year: np.ndarray):
|
||||
"""Return year as bitstring."""
|
||||
min_year = 1900
|
||||
max_training_year = 2018
|
||||
offseted_year = np.minimum(year, max_training_year) - min_year
|
||||
return np.unpackbits(offseted_year.astype(np.uint8), axis=-1)
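As a small worked example of the bitstring encoding above: years are clipped at 2018, shifted so that 1900 maps to 0, and each offset is unpacked into 8 bits. The sample years are made up, and the sketch assumes every year is at least 1900, since a negative offset would wrap around when cast to `uint8`.

```python
import numpy as np

# Years carry a trailing axis of size 1, as in the batch features.
year = np.array([[1900], [2018], [2020]])

offset = np.minimum(year, 2018) - 1900  # Values 0, 118, 118.
bits = np.unpackbits(offset.astype(np.uint8), axis=-1)  # Shape [3, 8].
```

Both 2018 and 2020 map to offset 118, which is `01110110` in binary, so later years are indistinguishable from the last training year.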
|
||||
|
||||
|
||||
def _np_one_hot(targets: np.ndarray, nb_classes: int):
|
||||
res = np.zeros(targets.shape + (nb_classes,), dtype=np.float16)
|
||||
np.put_along_axis(res, targets.astype(np.int32)[..., None], 1.0, axis=-1)
|
||||
return res
|
||||
|
||||
|
||||
def _get_one_hot_year_representation(
    year: np.ndarray,
    one_hot_type: np.ndarray,
):
  """Returns a one-hot encoding of the quantile bucket each year falls into."""
|
||||
# Bucket edges found based on quantiles to bucket into 20 equal sized buckets.
|
||||
bucket_edges = np.array([
|
||||
1964, 1975, 1983, 1989, 1994, 1998, 2001, 2004,
|
||||
2006, 2008, 2009, 2011, 2012, 2013, 2014, 2016,
|
||||
2017, # 2018, 2019, 2020 contain last-year-of-train, eval, test nodes
|
||||
])
|
||||
year = np.squeeze(year, axis=-1)
|
||||
year_id = np.searchsorted(bucket_edges, year)
|
||||
is_paper = one_hot_type[..., 0]
|
||||
bucket_id_for_non_paper = len(bucket_edges) + 1
|
||||
bucket_id = np.where(is_paper, year_id, bucket_id_for_non_paper)
|
||||
one_hot_year = _np_one_hot(bucket_id, len(bucket_edges) + 2)
|
||||
return one_hot_year
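The bucketing above relies on `np.searchsorted`, which with its default `side='left'` returns the number of bucket edges strictly smaller than each year. A minimal sketch with the same edges follows; the `year_bucket` helper and the sample years are hypothetical.

```python
import numpy as np

BUCKET_EDGES = np.array([
    1964, 1975, 1983, 1989, 1994, 1998, 2001, 2004,
    2006, 2008, 2009, 2011, 2012, 2013, 2014, 2016,
    2017,
])


def year_bucket(year, is_paper):
  """Returns the bucket id per node; non-papers get a dedicated last bucket."""
  bucket = np.searchsorted(BUCKET_EDGES, year)
  return np.where(is_paper, bucket, len(BUCKET_EDGES) + 1)


years = np.array([1950, 1964, 1965, 2017, 2019, 1900])
is_paper = np.array([True, True, True, True, True, False])
bucket_ids = year_bucket(years, is_paper)

# One-hot over 17 edges + 2 = 19 categories.
one_hot = np.eye(len(BUCKET_EDGES) + 2, dtype=np.float16)[bucket_ids]
```

A year equal to an edge falls into the bucket that ends at that edge, so 1964 shares bucket 0 with 1950, while 2018 and later land in bucket 17.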
|
||||
|
||||
|
||||
def _add_one_hot_features_to_batch(batch: Batch) -> Batch:
|
||||
"""Transforms integer features into one-hot features."""
|
||||
nodes = batch.graph.nodes.copy()
|
||||
nodes['one_hot_year'] = _get_one_hot_year_representation(
|
||||
nodes['year'], nodes['one_hot_type'])
|
||||
del nodes['year']
|
||||
|
||||
# NUM_CLASSES plus one category for papers for which a class is not provided
|
||||
# and another for nodes that are not papers.
|
||||
nodes['one_hot_label_as_feature'] = _np_one_hot(
|
||||
nodes['label_as_feature'], NUM_CLASSES + 2)
|
||||
del nodes['label_as_feature']
|
||||
return batch._replace(graph=batch.graph._replace(nodes=nodes))
|
||||
|
||||
|
||||
def _add_embeddings_to_batch(batch: Batch, embeddings: np.ndarray) -> Batch:
|
||||
nodes = batch.graph.nodes.copy()
|
||||
nodes['features'] = embeddings[batch.absolute_node_indices]
|
||||
graph = batch.graph._replace(nodes=nodes)
|
||||
return batch._replace(graph=graph)
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Download data required for training and evaluating models."""
|
||||
|
||||
import pathlib
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from google.cloud import storage
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import data_utils
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
|
||||
_BUCKET_NAME = 'deepmind-ogb-lsc'
|
||||
_MAX_DOWNLOAD_ATTEMPTS = 5
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_enum('payload', None, ['data', 'models'],
|
||||
'Download "data" or "models"?')
|
||||
flags.DEFINE_string('task_root', None, 'Local task root directory')
|
||||
|
||||
DATA_RELATIVE_PATHS = (
|
||||
data_utils.RAW_NODE_YEAR_FILENAME,
|
||||
data_utils.TRAIN_INDEX_FILENAME,
|
||||
data_utils.VALID_INDEX_FILENAME,
|
||||
data_utils.TEST_INDEX_FILENAME,
|
||||
data_utils.K_FOLD_SPLITS_DIR,
|
||||
data_utils.FUSED_NODE_LABELS_FILENAME,
|
||||
data_utils.FUSED_PAPER_EDGES_FILENAME,
|
||||
data_utils.FUSED_PAPER_EDGES_T_FILENAME,
|
||||
data_utils.EDGES_AUTHOR_INSTITUTION,
|
||||
data_utils.EDGES_INSTITUTION_AUTHOR,
|
||||
data_utils.EDGES_AUTHOR_PAPER,
|
||||
data_utils.EDGES_PAPER_AUTHOR,
|
||||
data_utils.PCA_MERGED_FEATURES_FILENAME,
|
||||
)
|
||||
|
||||
|
||||
class DataCorruptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_gcs_root():
|
||||
return Path('mag') / FLAGS.payload
|
||||
|
||||
|
||||
def _get_gcs_bucket():
|
||||
storage_client = storage.Client.create_anonymous_client()
|
||||
return storage_client.bucket(_BUCKET_NAME)
|
||||
|
||||
|
||||
def _write_blob_to_destination(blob, task_root, ignore_existing=True):
  """Downloads `blob` under `task_root`, retrying on checksum failures."""
|
||||
logging.info("Copying blob: '%s'", blob.name)
|
||||
destination_path = Path(task_root) / Path(*Path(blob.name).parts[1:])
|
||||
logging.info(" ... to: '%s'", str(destination_path))
|
||||
if ignore_existing and destination_path.exists():
|
||||
return
|
||||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checksum = 'crc32c'
|
||||
for attempt in range(_MAX_DOWNLOAD_ATTEMPTS):
|
||||
try:
|
||||
blob.download_to_filename(destination_path.as_posix(), checksum=checksum)
|
||||
except storage.client.resumable_media.common.DataCorruption:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise DataCorruptionError(f"Checksum ('{checksum}') for {blob.name} failed "
|
||||
f'after {attempt + 1} attempts')
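The retry loop above uses Python's `for ... else`: the `else` branch runs only if the loop finishes without hitting `break`, that is, when every attempt failed. A self-contained sketch of the same pattern follows; `call_with_retries`, `flaky` and the choice of `IOError` are hypothetical stand-ins for the download call and its corruption error.

```python
def call_with_retries(fn, max_attempts, retry_on=(IOError,)):
  """Calls `fn` until it succeeds or `max_attempts` (>= 1) is exhausted."""
  for _ in range(max_attempts):
    try:
      result = fn()
    except retry_on:
      pass  # Swallow the error and retry on the next iteration.
    else:
      break  # Success: this skips the loop's `else` clause.
  else:
    # Reached only when no iteration executed `break`.
    raise RuntimeError(f'Failed after {max_attempts} attempts')
  return result


calls = []


def flaky():
  """Fails on the first two calls, then succeeds."""
  calls.append(1)
  if len(calls) < 3:
    raise IOError('simulated corruption')
  return 'ok'


outcome = call_with_retries(flaky, max_attempts=5)
```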
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
bucket = _get_gcs_bucket()
|
||||
if FLAGS.payload == 'data':
|
||||
relative_paths = DATA_RELATIVE_PATHS
|
||||
else:
|
||||
relative_paths = (None,)
|
||||
for relative_path in relative_paths:
|
||||
if relative_path is None:
|
||||
relative_path = str(_get_gcs_root())
|
||||
else:
|
||||
relative_path = str(_get_gcs_root() / relative_path)
|
||||
logging.info("Copying relative path: '%s'", relative_path)
|
||||
blobs = bucket.list_blobs(prefix=relative_path)
|
||||
for blob in blobs:
|
||||
_write_blob_to_destination(blob, FLAGS.task_root)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('payload')
|
||||
flags.mark_flag_as_required('task_root')
|
||||
app.run(main)
|
||||
@@ -0,0 +1,227 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Ensemble k-fold predictions and generate final submission file."""
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import dill
|
||||
import jax
|
||||
import numpy as np
|
||||
from ogb import lsc
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import data_utils
|
||||
import losses
|
||||
|
||||
|
||||
_NUM_KFOLD_SPLITS = 10
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
_DATA_ROOT = flags.DEFINE_string('data_root', None, 'Path to the data root')
|
||||
_SPLIT = flags.DEFINE_enum('split', None, ['valid', 'test'], 'Data split')
|
||||
_PREDICTIONS_PATH = flags.DEFINE_string(
|
||||
'predictions_path', None, 'Path with the output of the k-fold models.')
|
||||
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Output path.')
|
||||
|
||||
|
||||
def _np_one_hot(targets: np.ndarray, nb_classes: int):
|
||||
res = np.zeros(targets.shape + (nb_classes,), dtype=np.float32)
|
||||
np.put_along_axis(res, targets.astype(np.int32)[..., None], 1.0, axis=-1)
|
||||
return res
|
||||
|
||||
|
||||
def ensemble_predictions(
|
||||
node_idx_to_logits_list,
|
||||
all_labels,
|
||||
node_indices,
|
||||
use_mode_break_tie_by_mean: bool = True,
|
||||
):
|
||||
"""Ensemble together predictions for each node and generate final predictions."""
|
||||
# First, assert that each node has the same number of predictions to ensemble.
|
||||
num_predictions_per_node = [
|
||||
len(x) for x in node_idx_to_logits_list.values()
|
||||
]
|
||||
num_models = np.unique(num_predictions_per_node)
|
||||
assert num_models.shape[0] == 1
|
||||
num_models = num_models[0]
|
||||
# Gather all logits, shape should be [num_nodes, num_models, num_classes].
|
||||
all_logits = np.stack(
|
||||
[np.stack(node_idx_to_logits_list[idx]) for idx in node_indices])
|
||||
assert all_logits.shape == (node_indices.shape[0], num_models,
|
||||
data_utils.NUM_CLASSES)
|
||||
# Softmax on the final axis.
|
||||
all_probs = jax.nn.softmax(all_logits, axis=-1)
|
||||
# Take average across models axis to get probabilities.
|
||||
mean_probs = np.mean(all_probs, axis=1)
|
||||
|
||||
# Check whether any model has tied maximum logits across classes.
|
||||
max_logit_value = np.max(all_logits, axis=-1)
|
||||
num_classes_with_max_value = (
|
||||
all_logits == max_logit_value[..., None]).sum(axis=-1)
|
||||
|
||||
num_logit_ties = (num_classes_with_max_value > 1).sum()
|
||||
if num_logit_ties:
|
||||
logging.warning(
|
||||
'Found %d models with the exact same logits for two of the classes. '
|
||||
'`argmax` will choose the first.', num_logit_ties)
|
||||
|
||||
# Each model votes for one class per node.
|
||||
all_votes = np.argmax(all_logits, axis=-1)
|
||||
assert all_votes.shape == (node_indices.shape[0], num_models)
|
||||
|
||||
all_votes_one_hot = _np_one_hot(all_votes, data_utils.NUM_CLASSES)
|
||||
assert all_votes_one_hot.shape == (node_indices.shape[0], num_models,
|
||||
data_utils.NUM_CLASSES)
|
||||
|
||||
num_votes_per_class = np.sum(all_votes_one_hot, axis=1)
|
||||
assert num_votes_per_class.shape == (
|
||||
node_indices.shape[0], data_utils.NUM_CLASSES)
|
||||
|
||||
if use_mode_break_tie_by_mean:
|
||||
# Slight hack, give high weight to votes (any number > 1 works really)
|
||||
# and add probabilities between [0, 1] per class to tie-break only within
|
||||
# classes with equal votes.
|
||||
total_score = 10 * num_votes_per_class + mean_probs
|
||||
else:
|
||||
# Just take mean.
|
||||
total_score = mean_probs
|
||||
|
||||
ensembled_logits = np.log(total_score)
|
||||
return losses.Predictions(
|
||||
node_indices=node_indices,
|
||||
labels=all_labels,
|
||||
logits=ensembled_logits,
|
||||
predictions=np.argmax(ensembled_logits, axis=-1),
|
||||
)
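The scoring rule above can be illustrated on a toy ensemble. Votes are weighted by 10 and the mean probability, which lies in [0, 1], only decides between classes with the same number of votes. The sketch below uses a NumPy softmax instead of `jax.nn.softmax`; the `vote_with_mean_tiebreak` helper and the logits are made up.

```python
import numpy as np


def softmax(x, axis=-1):
  """Numerically stable softmax."""
  x = x - x.max(axis=axis, keepdims=True)
  e = np.exp(x)
  return e / e.sum(axis=axis, keepdims=True)


def vote_with_mean_tiebreak(logits):
  """Majority vote over models; ties are broken by mean probability.

  Args:
    logits: Array of shape [num_nodes, num_models, num_classes].
  Returns:
    Predicted class per node, shape [num_nodes].
  """
  num_classes = logits.shape[-1]
  mean_probs = softmax(logits).mean(axis=1)  # [nodes, classes]
  votes = logits.argmax(axis=-1)  # [nodes, models]
  votes_per_class = np.eye(num_classes)[votes].sum(axis=1)  # [nodes, classes]
  return (10 * votes_per_class + mean_probs).argmax(axis=-1)


logits = np.array([
    # Node 0: a three-way tie in votes; class 1 has the highest mean prob.
    [[2.0, 0.0, 0.0], [0.0, 5.0, 0.0], [0.0, 0.0, 1.0]],
    # Node 1: two weak votes for class 2 beat one confident vote for class 0.
    [[0.0, 0.0, 0.1], [0.0, 0.0, 0.1], [9.0, 0.0, 0.0]],
])
predictions = vote_with_mean_tiebreak(logits)
```

For node 1 the mean probability alone would favour class 0, which shows how the vote count dominates the score.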
|
||||
|
||||
|
||||
def load_predictions(predictions_path, split):
  """Loads the per-node logits of every k-fold model for the given split."""
|
||||
|
||||
# Generate list of predictions per node.
|
||||
# Note for validation each validation index is only present in exactly 1
|
||||
# model of the k-fold, however for test it is present in all of them.
|
||||
node_idx_to_logits_list = collections.defaultdict(list)
|
||||
|
||||
# For the 10 models in the ensemble.
|
||||
for i in range(_NUM_KFOLD_SPLITS):
|
||||
path = os.path.join(predictions_path, str(i))
|
||||
|
||||
# Find subdirectories.
|
||||
# Directories will be something like:
|
||||
# os.path.join(path, "step_104899_2021-06-14T18:20:05", "(test|valid).dill")
|
||||
# So we make sure there is only one.
|
||||
candidates = []
|
||||
for date_str in os.listdir(path):
|
||||
candidate_path = os.path.join(path, date_str, f'{split}.dill')
|
||||
if os.path.exists(candidate_path):
|
||||
candidates.append(candidate_path)
|
||||
if not candidates:
|
||||
raise ValueError(f'No {split} predictions found at {path}')
|
||||
elif len(candidates) > 1:
|
||||
raise ValueError(f'Found more than one {split} predictions: {candidates}')
|
||||
|
||||
path_for_kth_model_predictions = candidates[0]
|
||||
with open(path_for_kth_model_predictions, 'rb') as f:
|
||||
results = dill.load(f)
|
||||
logging.info('Loaded %s', path_for_kth_model_predictions)
|
||||
for (node_idx, logits) in zip(results.node_indices,
|
||||
results.logits):
|
||||
node_idx_to_logits_list[node_idx].append(logits)
|
||||
|
||||
return node_idx_to_logits_list
|
||||
|
||||
|
||||
def generate_ensembled_predictions(
    data_root: str, predictions_path: str, split: str) -> losses.Predictions:
  """Ensembles the predictions of all k-fold models for the given split."""
|
||||
|
||||
array_dict = data_utils.get_arrays(
|
||||
data_root=data_root,
|
||||
return_pca_embeddings=False,
|
||||
return_adjacencies=False)
|
||||
|
||||
# Load the predictions for the requested split.
|
||||
node_idx_to_logits_list = load_predictions(predictions_path, split)
|
||||
|
||||
# Assert that the indices loaded are as expected.
|
||||
expected_idx = array_dict[f'{split}_indices']
|
||||
idx_found = np.array(list(node_idx_to_logits_list.keys()))
|
||||
assert np.all(np.sort(idx_found) == expected_idx)
|
||||
|
||||
if split == 'valid':
|
||||
true_labels = array_dict['paper_label'][expected_idx.astype(np.int32)]
|
||||
else:
|
||||
# Don't know the test labels.
|
||||
true_labels = np.full(expected_idx.shape, np.nan)
|
||||
|
||||
# Ensemble together all predictions.
|
||||
return ensemble_predictions(
|
||||
node_idx_to_logits_list, true_labels, expected_idx)
|
||||
|
||||
|
||||
def evaluate_validation(valid_predictions):
|
||||
|
||||
evaluator = lsc.MAG240MEvaluator()
|
||||
|
||||
evaluator_output = evaluator.eval(
|
||||
dict(y_pred=valid_predictions.predictions.astype(np.float64),
|
||||
y_true=valid_predictions.labels))
|
||||
logging.info(
|
||||
'Validation accuracy as reported by MAG240MEvaluator: %s',
|
||||
evaluator_output)
|
||||
|
||||
|
||||
def save_test_submission_file(test_predictions, output_dir):
|
||||
evaluator = lsc.MAG240MEvaluator()
|
||||
evaluator.save_test_submission(
|
||||
dict(y_pred=test_predictions.predictions.astype(np.float64)), output_dir)
|
||||
logging.info('Test submission file generated at %s', output_dir)
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
|
||||
split = _SPLIT.value
|
||||
|
||||
ensembled_predictions = generate_ensembled_predictions(
|
||||
data_root=_DATA_ROOT.value,
|
||||
predictions_path=_PREDICTIONS_PATH.value,
|
||||
split=split)
|
||||
|
||||
output_dir = _OUTPUT_PATH.value
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if split == 'valid':
|
||||
evaluate_validation(ensembled_predictions)
|
||||
elif split == 'test':
|
||||
save_test_submission_file(ensembled_predictions, output_dir)
|
||||
|
||||
ensembled_predictions_path = os.path.join(output_dir, f'{split}.dill')
|
||||
assert not os.path.exists(ensembled_predictions_path)
|
||||
with open(ensembled_predictions_path, 'wb') as f:
|
||||
dill.dump(ensembled_predictions, f)
|
||||
logging.info(
|
||||
'%s predictions stored at %s', split, ensembled_predictions_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
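The script above hands the per-node lists of logits to `ensemble_predictions`, whose body is not shown here. As a rough illustration only, the sketch below averages softmax probabilities per node, an assumption borrowed from `ensemble_predictions_by_probability_average` in `losses.py`. The helper names `softmax` and `average_probabilities` are hypothetical and not part of the repository. The sketch tolerates nodes with different numbers of logits, which the comment in `load_predictions` suggests can happen (a validation node may be predicted by fewer models than a test node).

```python
import collections

import numpy as np


def softmax(x):
  """Numerically stable softmax over the last axis."""
  x = x - np.max(x, axis=-1, keepdims=True)
  e = np.exp(x)
  return e / np.sum(e, axis=-1, keepdims=True)


def average_probabilities(node_idx_to_logits_list):
  """Averages class probabilities over all models that predicted each node."""
  node_indices = np.array(sorted(node_idx_to_logits_list))
  mean_probs = np.stack([
      softmax(np.stack(node_idx_to_logits_list[idx])).mean(axis=0)
      for idx in node_indices
  ])
  return node_indices, mean_probs, np.argmax(mean_probs, axis=-1)


# Toy data: node 7 has logits from two models, node 3 from one.
toy_logits = collections.defaultdict(list)
toy_logits[7].append(np.array([2.0, 0.0, 0.0]))
toy_logits[7].append(np.array([0.0, 1.0, 0.0]))
toy_logits[3].append(np.array([0.0, 0.0, 5.0]))

indices, probs, preds = average_probabilities(toy_logits)
```

Averaging probabilities rather than raw logits keeps each model's vote on the same scale, which matters when the models are not calibrated identically.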
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,51 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generates the k-fold validation splits."""
|
||||
|
||||
import os
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
import data_utils
|
||||
|
||||
|
||||
_DATA_ROOT = flags.DEFINE_string(
|
||||
'data_root', None, required=True,
|
||||
help='Path containing the downloaded data.')
|
||||
|
||||
|
||||
_OUTPUT_DIR = flags.DEFINE_string(
|
||||
'output_dir', None, required=True,
|
||||
help='Output directory to write the splits to')
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
array_dict = data_utils.get_arrays(
|
||||
data_root=_DATA_ROOT.value,
|
||||
return_pca_embeddings=False,
|
||||
return_adjacencies=False)
|
||||
|
||||
os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
|
||||
data_utils.generate_k_fold_splits(
|
||||
train_idx=array_dict['train_indices'],
|
||||
valid_idx=array_dict['valid_indices'],
|
||||
output_path=_OUTPUT_DIR.value,
|
||||
num_splits=data_utils.NUM_K_FOLD_SPLITS)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
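The script above delegates the actual splitting to `data_utils.generate_k_fold_splits`, which is not shown in this file. The following is a minimal sketch of one plausible scheme, under the assumption that the validation indices are divided into folds and each split trains on the original training set plus all but one fold. The function name `k_fold_splits`, the `seed` argument and the returned structure are hypothetical, and the sketch expects at least two splits.

```python
import numpy as np


def k_fold_splits(train_idx, valid_idx, num_splits, seed=0):
  """Returns a list of (train, held_out_valid) index pairs, one per fold."""
  rng = np.random.default_rng(seed)
  folds = np.array_split(rng.permutation(valid_idx), num_splits)
  splits = []
  for k in range(num_splits):
    held_out = np.sort(folds[k])
    # Every other validation fold is added to the training set.
    extra_train = np.concatenate(
        [fold for i, fold in enumerate(folds) if i != k])
    new_train = np.sort(np.concatenate([train_idx, extra_train]))
    splits.append((new_train, held_out))
  return splits


toy_splits = k_fold_splits(np.arange(10), np.arange(10, 20), num_splits=5)
```

Under this scheme each validation node is held out by exactly one model, so a held-out prediction exists for every validation node.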
|
||||
@@ -0,0 +1,199 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Losses and related utilities."""
|
||||
|
||||
from typing import Mapping, Tuple, Sequence, NamedTuple, Dict, Optional
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jraph
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import datasets
|
||||
|
||||
LogsDict = Mapping[str, jnp.ndarray]
|
||||
|
||||
|
||||
class Predictions(NamedTuple):
|
||||
node_indices: np.ndarray
|
||||
labels: np.ndarray
|
||||
predictions: np.ndarray
|
||||
logits: np.ndarray
|
||||
|
||||
|
||||
def node_classification_loss(
|
||||
logits: jnp.ndarray,
|
||||
batch: datasets.Batch,
|
||||
extra_stats: bool = False,
|
||||
) -> Tuple[jnp.ndarray, LogsDict]:
|
||||
"""Gets node-wise classification loss and statistics."""
|
||||
log_probs = jax.nn.log_softmax(logits)
|
||||
loss = -jnp.sum(log_probs * batch.node_labels, axis=-1)
|
||||
|
||||
num_valid = jnp.sum(batch.label_mask)
|
||||
labels = jnp.argmax(batch.node_labels, axis=-1)
|
||||
is_correct = (jnp.argmax(log_probs, axis=-1) == labels)
|
||||
num_correct = jnp.sum(is_correct * batch.label_mask)
|
||||
loss = jnp.sum(loss * batch.label_mask) / (num_valid + 1e-8)
|
||||
accuracy = num_correct / (num_valid + 1e-8)
|
||||
|
||||
entropy = -jnp.mean(jnp.sum(jax.nn.softmax(logits) * log_probs, axis=-1))
|
||||
|
||||
stats = {
|
||||
'classification_loss': loss,
|
||||
'prediction_entropy': entropy,
|
||||
'accuracy': accuracy,
|
||||
'num_valid': num_valid,
|
||||
'num_correct': num_correct,
|
||||
}
|
||||
if extra_stats:
|
||||
for k in range(1, 6):
|
||||
stats[f'top_{k}_correct'] = topk_correct(logits, labels,
|
||||
batch.label_mask, k)
|
||||
return loss, stats
|
||||
|
||||
|
||||
def get_predictions_labels_and_logits(
|
||||
logits: jnp.ndarray,
|
||||
batch: datasets.Batch,
|
||||
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
|
||||
"""Gets node indices, predictions, labels and logits for labelled nodes."""
|
||||
mask = batch.label_mask > 0.
|
||||
indices = batch.node_indices[mask]
|
||||
logits = logits[mask]
|
||||
predictions = jnp.argmax(logits, axis=-1)
|
||||
labels = jnp.argmax(batch.node_labels[mask], axis=-1)
|
||||
return indices, predictions, labels, logits
|
||||
|
||||
|
||||
def topk_correct(
|
||||
logits: jnp.ndarray,
|
||||
labels: jnp.ndarray,
|
||||
valid_mask: jnp.ndarray,
|
||||
topk: int,
|
||||
) -> jnp.ndarray:
|
||||
"""Counts the valid nodes whose label is among the top-k predictions."""
|
||||
pred_ranking = jnp.argsort(logits, axis=1)[:, ::-1]
|
||||
pred_ranking = pred_ranking[:, :topk]
|
||||
is_correct = jnp.any(pred_ranking == labels[:, jnp.newaxis], axis=1)
|
||||
return (is_correct * valid_mask).sum()
|
||||
|
||||
|
||||
def ensemble_predictions_by_probability_average(
|
||||
predictions_list: Sequence[Predictions]) -> Predictions:
|
||||
"""Ensembles predictions by averaging the per-model class probabilities."""
|
||||
_assert_consistent_predictions(predictions_list)
|
||||
all_probs = np.stack([
|
||||
jax.nn.softmax(predictions.logits, axis=-1)
|
||||
for predictions in predictions_list
|
||||
],
|
||||
axis=0)
|
||||
ensembled_logits = np.log(all_probs.mean(0))
|
||||
return predictions_list[0]._replace(
|
||||
logits=ensembled_logits, predictions=np.argmax(ensembled_logits, axis=-1))
|
||||
|
||||
|
||||
def get_accuracy_dict(predictions: Predictions) -> Dict[str, float]:
|
||||
"""Returns the accuracy dict."""
|
||||
output_dict = {}
|
||||
output_dict['num_valid'] = predictions.predictions.shape[0]
|
||||
matches = (predictions.labels == predictions.predictions)
|
||||
output_dict['accuracy'] = matches.mean()
|
||||
|
||||
pred_ranking = jnp.argsort(predictions.logits, axis=1)[:, ::-1]
|
||||
for k in range(1, 6):
|
||||
matches = jnp.any(
|
||||
pred_ranking[:, :k] == predictions.labels[:, None], axis=1)
|
||||
output_dict[f'top_{k}_correct'] = matches.mean()
|
||||
return output_dict
|
||||
|
||||
|
||||
def bgrl_loss(
|
||||
first_online_predictions: jnp.ndarray,
|
||||
second_target_projections: jnp.ndarray,
|
||||
second_online_predictions: jnp.ndarray,
|
||||
first_target_projections: jnp.ndarray,
|
||||
symmetrize: bool,
|
||||
valid_mask: jnp.ndarray,
|
||||
) -> Tuple[jnp.ndarray, LogsDict]:
|
||||
"""Implements BGRL loss."""
|
||||
first_side_node_loss = jnp.sum(
|
||||
jnp.square(
|
||||
_l2_normalize(first_online_predictions, axis=-1) -
|
||||
_l2_normalize(second_target_projections, axis=-1)),
|
||||
axis=-1)
|
||||
if symmetrize:
|
||||
second_side_node_loss = jnp.sum(
|
||||
jnp.square(
|
||||
_l2_normalize(second_online_predictions, axis=-1) -
|
||||
_l2_normalize(first_target_projections, axis=-1)),
|
||||
axis=-1)
|
||||
node_loss = first_side_node_loss + second_side_node_loss
|
||||
else:
|
||||
node_loss = first_side_node_loss
|
||||
loss = (node_loss * valid_mask).sum() / (valid_mask.sum() + 1e-6)
|
||||
return loss, dict(bgrl_loss=loss)
|
||||
|
||||
|
||||
def get_corrupted_view(
|
||||
graph: jraph.GraphsTuple,
|
||||
feature_drop_prob: float,
|
||||
edge_drop_prob: float,
|
||||
rng_key: jnp.ndarray,
|
||||
) -> jraph.GraphsTuple:
|
||||
"""Returns corrupted graph view."""
|
||||
node_key, edge_key = jax.random.split(rng_key)
|
||||
|
||||
def mask_feature(x):
|
||||
mask = jax.random.bernoulli(node_key, 1 - feature_drop_prob, x.shape)
|
||||
return x * mask
|
||||
|
||||
# Randomly mask features with fixed probability.
|
||||
nodes = jax.tree_util.tree_map(mask_feature, graph.nodes)
|
||||
|
||||
# Simulate dropping of edges by changing genuine edges to self-loops on
|
||||
# the padded node.
|
||||
num_edges = graph.senders.shape[0]
|
||||
last_node_idx = graph.n_node.sum() - 1
|
||||
edge_mask = jax.random.bernoulli(edge_key, 1 - edge_drop_prob, [num_edges])
|
||||
senders = jnp.where(edge_mask, graph.senders, last_node_idx)
|
||||
receivers = jnp.where(edge_mask, graph.receivers, last_node_idx)
|
||||
# Note that n_edge will now be invalid since edges in the middle of the list
|
||||
# will correspond to the final graph. Set n_edge to None to ensure we do not
|
||||
# accidentally use this.
|
||||
return graph._replace(
|
||||
nodes=nodes,
|
||||
senders=senders,
|
||||
receivers=receivers,
|
||||
n_edge=None,
|
||||
)
|
||||
|
||||
|
||||
def _assert_consistent_predictions(predictions_list: Sequence[Predictions]):
|
||||
first_predictions = predictions_list[0]
|
||||
for predictions in predictions_list:
|
||||
assert np.all(predictions.node_indices == first_predictions.node_indices)
|
||||
assert np.all(predictions.labels == first_predictions.labels)
|
||||
assert np.all(
|
||||
predictions.predictions == np.argmax(predictions.logits, axis=-1))
|
||||
|
||||
|
||||
def _l2_normalize(
|
||||
x: jnp.ndarray,
|
||||
axis: Optional[int] = None,
|
||||
epsilon: float = 1e-6,
|
||||
) -> jnp.ndarray:
|
||||
return x * jax.lax.rsqrt(
|
||||
jnp.sum(jnp.square(x), axis=axis, keepdims=True) + epsilon)
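To make the per-node term in `bgrl_loss` concrete: the squared distance between two L2-normalized vectors equals `2 - 2 * cosine_similarity`, up to the small `epsilon` used in the normalization. The NumPy sketch below is a one-sided, illustrative re-implementation on toy inputs, not the JAX code above. The names `l2_normalize` and `bgrl_node_loss` and all toy values are assumptions made for the example.

```python
import numpy as np


def l2_normalize(x, axis=-1, epsilon=1e-6):
  """Scales vectors to (approximately) unit length along `axis`."""
  return x / np.sqrt(np.sum(np.square(x), axis=axis, keepdims=True) + epsilon)


def bgrl_node_loss(online_predictions, target_projections):
  """Squared distance between normalized predictions and projections."""
  diff = l2_normalize(online_predictions) - l2_normalize(target_projections)
  return np.sum(np.square(diff), axis=-1)


online = np.array([[3.0, 4.0], [1.0, 0.0], [5.0, 5.0]])
target = np.array([[3.0, 4.0], [0.0, 1.0], [-1.0, -1.0]])
valid_mask = np.array([1.0, 1.0, 0.0])  # The last node is padding.

node_loss = bgrl_node_loss(online, target)
cosine = np.sum(l2_normalize(online) * l2_normalize(target), axis=-1)
# Masked mean over valid nodes, mirroring the reduction in `bgrl_loss`.
loss = np.sum(node_loss * valid_mask) / (np.sum(valid_mask) + 1e-6)
```

Aligned vectors contribute roughly 0, orthogonal ones roughly 2 and opposite ones roughly 4, so the per-node term lies between 0 and 4.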
|
||||
@@ -0,0 +1,270 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MAG240M-LSC models."""
|
||||
|
||||
from typing import Callable, NamedTuple, Sequence
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jraph
|
||||
|
||||
|
||||
_REDUCER_NAMES = {
|
||||
'sum':
|
||||
jax.ops.segment_sum,
|
||||
'mean':
|
||||
jraph.segment_mean,
|
||||
'softmax':
|
||||
jraph.segment_softmax,
|
||||
}
|
||||
|
||||
|
||||
class ModelOutput(NamedTuple):
|
||||
node_embeddings: jnp.ndarray
|
||||
node_embedding_projections: jnp.ndarray
|
||||
node_projection_predictions: jnp.ndarray
|
||||
node_logits: jnp.ndarray
|
||||
|
||||
|
||||
def build_update_fn(
|
||||
name: str,
|
||||
output_sizes: Sequence[int],
|
||||
activation: Callable[[jnp.ndarray], jnp.ndarray],
|
||||
normalization_type: str,
|
||||
is_training: bool,
|
||||
):
|
||||
"""Builds update function."""
|
||||
|
||||
def single_mlp(inner_name: str):
|
||||
"""Creates a single MLP performing the update."""
|
||||
mlp = hk.nets.MLP(
|
||||
output_sizes=output_sizes,
|
||||
name=inner_name,
|
||||
activation=activation)
|
||||
mlp = jraph.concatenated_args(mlp)
|
||||
if normalization_type == 'layer_norm':
|
||||
norm = hk.LayerNorm(
|
||||
axis=-1,
|
||||
create_scale=True,
|
||||
create_offset=True,
|
||||
name=name + '_layer_norm')
|
||||
elif normalization_type == 'batch_norm':
|
||||
batch_norm = hk.BatchNorm(
|
||||
create_scale=True,
|
||||
create_offset=True,
|
||||
decay_rate=0.9,
|
||||
name=f'{inner_name}_batch_norm',
|
||||
cross_replica_axis=None if hk.running_init() else 'i',
|
||||
)
|
||||
norm = lambda x: batch_norm(x, is_training)
|
||||
elif normalization_type == 'none':
|
||||
return mlp
|
||||
else:
|
||||
raise ValueError(f'Unknown normalization type {normalization_type}')
|
||||
return jraph.concatenated_args(hk.Sequential([mlp, norm]))
|
||||
|
||||
return single_mlp(f'{name}_homogeneous')
|
||||
|
||||
|
||||
def build_gn(
|
||||
output_sizes: Sequence[int],
|
||||
activation: Callable[[jnp.ndarray], jnp.ndarray],
|
||||
suffix: str,
|
||||
use_sent_edges: bool,
|
||||
is_training: bool,
|
||||
dropedge_rate: float,
|
||||
normalization_type: str,
|
||||
aggregation_function: str,
|
||||
):
|
||||
"""Builds an InteractionNetwork with MLP update functions."""
|
||||
node_update_fn = build_update_fn(
|
||||
f'node_processor_{suffix}',
|
||||
output_sizes,
|
||||
activation=activation,
|
||||
normalization_type=normalization_type,
|
||||
is_training=is_training,
|
||||
)
|
||||
edge_update_fn = build_update_fn(
|
||||
f'edge_processor_{suffix}',
|
||||
output_sizes,
|
||||
activation=activation,
|
||||
normalization_type=normalization_type,
|
||||
is_training=is_training,
|
||||
)
|
||||
|
||||
def maybe_dropedge(x):
|
||||
"""Dropout on edge messages."""
|
||||
if not is_training:
|
||||
return x
|
||||
return x * hk.dropout(
|
||||
hk.next_rng_key(),
|
||||
dropedge_rate,
|
||||
jnp.ones([x.shape[0], 1]),
|
||||
)
|
||||
|
||||
dropped_edge_update_fn = lambda *args: maybe_dropedge(edge_update_fn(*args))
|
||||
return jraph.InteractionNetwork(
|
||||
update_edge_fn=dropped_edge_update_fn,
|
||||
update_node_fn=node_update_fn,
|
||||
aggregate_edges_for_nodes_fn=_REDUCER_NAMES[aggregation_function],
|
||||
include_sent_messages_in_node_update=use_sent_edges,
|
||||
)
|
||||
|
||||
|
||||
def _get_activation_fn(name: str) -> Callable[[jnp.ndarray], jnp.ndarray]:
|
||||
if name == 'identity':
|
||||
return lambda x: x
|
||||
if hasattr(jax.nn, name):
|
||||
return getattr(jax.nn, name)
|
||||
raise ValueError(f'Unknown activation function {name} specified. '
|
||||
'See https://jax.readthedocs.io/en/latest/jax.nn.html '
|
||||
'for the list of supported function names.')
|
||||
|
||||
|
||||
class NodePropertyEncodeProcessDecode(hk.Module):
|
||||
"""Node Property Prediction Encode Process Decode Model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mlp_hidden_sizes: Sequence[int],
|
||||
latent_size: int,
|
||||
num_classes: int,
|
||||
num_message_passing_steps: int = 2,
|
||||
activation: str = 'relu',
|
||||
dropout_rate: float = 0.0,
|
||||
dropedge_rate: float = 0.0,
|
||||
use_sent_edges: bool = False,
|
||||
disable_edge_updates: bool = False,
|
||||
normalization_type: str = 'layer_norm',
|
||||
aggregation_function: str = 'sum',
|
||||
name='NodePropertyEncodeProcessDecode',
|
||||
):
|
||||
super().__init__(name=name)
|
||||
self._num_classes = num_classes
|
||||
self._latent_size = latent_size
|
||||
self._output_sizes = list(mlp_hidden_sizes) + [latent_size]
|
||||
self._num_message_passing_steps = num_message_passing_steps
|
||||
self._activation = _get_activation_fn(activation)
|
||||
self._dropout_rate = dropout_rate
|
||||
self._dropedge_rate = dropedge_rate
|
||||
self._use_sent_edges = use_sent_edges
|
||||
self._disable_edge_updates = disable_edge_updates
|
||||
self._normalization_type = normalization_type
|
||||
self._aggregation_function = aggregation_function
|
||||
|
||||
def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple:
|
||||
node_key, edge_key = hk.next_rng_keys(2)
|
||||
nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes)
|
||||
edges = graph.edges
|
||||
if not self._disable_edge_updates:
|
||||
edges = hk.dropout(edge_key, self._dropout_rate, edges)
|
||||
return graph._replace(nodes=nodes, edges=edges)
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
graph: jraph.GraphsTuple,
|
||||
is_training: bool,
|
||||
) -> jraph.GraphsTuple:
|
||||
node_embed_fn = build_update_fn(
|
||||
'node_encoder',
|
||||
self._output_sizes,
|
||||
activation=self._activation,
|
||||
normalization_type=self._normalization_type,
|
||||
is_training=is_training,
|
||||
)
|
||||
edge_embed_fn = build_update_fn(
|
||||
'edge_encoder',
|
||||
self._output_sizes,
|
||||
activation=self._activation,
|
||||
normalization_type=self._normalization_type,
|
||||
is_training=is_training,
|
||||
)
|
||||
gn = jraph.GraphMapFeatures(edge_embed_fn, node_embed_fn)
|
||||
graph = gn(graph)
|
||||
if is_training:
|
||||
graph = self._dropout_graph(graph)
|
||||
return graph
|
||||
|
||||
def _process(
|
||||
self,
|
||||
graph: jraph.GraphsTuple,
|
||||
is_training: bool,
|
||||
) -> jraph.GraphsTuple:
|
||||
for idx in range(self._num_message_passing_steps):
|
||||
net = build_gn(
|
||||
output_sizes=self._output_sizes,
|
||||
activation=self._activation,
|
||||
suffix=str(idx),
|
||||
use_sent_edges=self._use_sent_edges,
|
||||
is_training=is_training,
|
||||
dropedge_rate=self._dropedge_rate,
|
||||
normalization_type=self._normalization_type,
|
||||
aggregation_function=self._aggregation_function)
|
||||
residual_graph = net(graph)
|
||||
graph = graph._replace(nodes=graph.nodes + residual_graph.nodes)
|
||||
if not self._disable_edge_updates:
|
||||
graph = graph._replace(edges=graph.edges + residual_graph.edges)
|
||||
if is_training:
|
||||
graph = self._dropout_graph(graph)
|
||||
return graph
|
||||
|
||||
def _node_mlp(
|
||||
self,
|
||||
graph: jraph.GraphsTuple,
|
||||
is_training: bool,
|
||||
output_size: int,
|
||||
name: str,
|
||||
) -> jnp.ndarray:
|
||||
decoder_sizes = list(self._output_sizes[:-1]) + [output_size]
|
||||
net = build_update_fn(
|
||||
name,
|
||||
decoder_sizes,
|
||||
self._activation,
|
||||
normalization_type=self._normalization_type,
|
||||
is_training=is_training,
|
||||
)
|
||||
return net(graph.nodes)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
graph: jraph.GraphsTuple,
|
||||
is_training: bool,
|
||||
stop_gradient_embedding_to_logits: bool = False,
|
||||
) -> ModelOutput:
|
||||
# Note that these update configs may need to change if
|
||||
# we switch back to GraphNetwork rather than InteractionNetwork.
|
||||
|
||||
graph = self._encode(graph, is_training)
|
||||
graph = self._process(graph, is_training)
|
||||
node_embeddings = graph.nodes
|
||||
node_projections = self._node_mlp(graph, is_training, self._latent_size,
|
||||
'projector')
|
||||
node_predictions = self._node_mlp(
|
||||
graph._replace(nodes=node_projections),
|
||||
is_training,
|
||||
self._latent_size,
|
||||
'predictor',
|
||||
)
|
||||
if stop_gradient_embedding_to_logits:
|
||||
graph = jax.tree_util.tree_map(jax.lax.stop_gradient, graph)
|
||||
node_logits = self._node_mlp(graph, is_training, self._num_classes,
|
||||
'logits_decoder')
|
||||
return ModelOutput(
|
||||
node_embeddings=node_embeddings,
|
||||
node_logits=node_logits,
|
||||
node_embedding_projections=node_projections,
|
||||
node_projection_predictions=node_predictions,
|
||||
)
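The `maybe_dropedge` helper in `build_gn` applies dropout to a column of ones with one entry per edge and multiplies the messages by it, so each edge message is kept or zeroed as a whole rather than feature by feature. A minimal NumPy sketch of that idea follows. It is an illustration rather than the Haiku code path, the function name `drop_edge_messages` is hypothetical, and it assumes `drop_rate < 1`.

```python
import numpy as np


def drop_edge_messages(messages, drop_rate, rng):
  """Zeroes whole edge messages with probability `drop_rate`.

  Kept messages are rescaled by 1 / (1 - drop_rate), as in standard
  inverted dropout, so the expected message is unchanged.
  """
  # One keep/drop decision per edge, broadcast across the feature axis.
  keep = rng.random((messages.shape[0], 1)) >= drop_rate
  return messages * keep / (1.0 - drop_rate)


rng = np.random.default_rng(0)
toy_messages = np.ones((1000, 4))
dropped = drop_edge_messages(toy_messages, drop_rate=0.5, rng=rng)
```

Dropping a message outright approximates removing the edge for that step, which per-feature dropout on edge features does not do.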
|
||||
@@ -0,0 +1,175 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Find neighborhoods around paper feature embeddings."""
|
||||
|
||||
import pathlib
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import annoy
|
||||
import numpy as np
|
||||
import scipy.sparse as sp
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import data_utils
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
|
||||
_PAPER_PAPER_B_PATH = 'ogb_mag_adjacencies/paper_paper_b.npz'
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('data_root', None, 'Data root directory')
|
||||
|
||||
|
||||
def _read_paper_pca_features():
|
||||
data_root = Path(FLAGS.data_root)
|
||||
path = data_root / data_utils.PCA_PAPER_FEATURES_FILENAME
|
||||
with open(path, 'rb') as fid:
|
||||
return np.load(fid)
|
||||
|
||||
|
||||
def _read_adjacency_indices():
|
||||
# Get adjacencies.
|
||||
return data_utils.get_arrays(
|
||||
data_root=FLAGS.data_root,
|
||||
use_fused_node_labels=False,
|
||||
use_fused_node_adjacencies=False,
|
||||
return_pca_embeddings=False,
|
||||
)
|
||||
|
||||
|
||||
def build_annoy_index(features):
|
||||
"""Build the Annoy index."""
|
||||
logging.info('Building annoy index')
|
||||
num_vectors, vector_size = features.shape
|
||||
annoy_index = annoy.AnnoyIndex(vector_size, 'euclidean')
|
||||
for i, x in enumerate(features):
|
||||
annoy_index.add_item(i, x)
|
||||
if i % 1000000 == 0:
|
||||
logging.info('Adding: %d / %d (%.3g %%)', i, num_vectors,
|
||||
100 * i / num_vectors)
|
||||
n_trees = 10
|
||||
_ = annoy_index.build(n_trees)
|
||||
return annoy_index
|
||||
|
||||
|
||||
def _get_annoy_index_path():
|
||||
return Path(FLAGS.data_root) / data_utils.PREPROCESSED_DIR / 'annoy_index.ann'
|
||||
|
||||
|
||||
def save_annoy_index(annoy_index):
|
||||
logging.info('Saving annoy index')
|
||||
index_path = _get_annoy_index_path()
|
||||
index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
annoy_index.save(str(index_path))
|
||||
|
||||
|
||||
def read_annoy_index(features):
|
||||
index_path = _get_annoy_index_path()
|
||||
vector_size = features.shape[1]
|
||||
annoy_index = annoy.AnnoyIndex(vector_size, 'euclidean')
|
||||
annoy_index.load(str(index_path))
|
||||
return annoy_index
|
||||
|
||||
|
||||
def compute_neighbor_indices_and_distances(features):
|
||||
"""Use the pre-built Annoy index to compute neighbor indices and distances."""
|
||||
logging.info('Computing neighbors and distances')
|
||||
annoy_index = read_annoy_index(features)
|
||||
num_vectors = features.shape[0]
|
||||
|
||||
k = 20
|
||||
pad_k = 5
|
||||
search_k = -1
|
||||
neighbor_indices = np.zeros([num_vectors, k + pad_k + 1], dtype=np.int32)
|
||||
neighbor_distances = np.zeros([num_vectors, k + pad_k + 1], dtype=np.float32)
|
||||
for i in range(num_vectors):
|
||||
neighbor_indices[i], neighbor_distances[i] = annoy_index.get_nns_by_item(
|
||||
i, k + pad_k + 1, search_k=search_k, include_distances=True)
|
||||
if i % 10000 == 0:
|
||||
logging.info('Finding neighbors %d / %d', i, num_vectors)
|
||||
return neighbor_indices, neighbor_distances
|
||||
|
||||
|
||||
def _write_neighbors(neighbor_indices, neighbor_distances):
|
||||
"""Write neighbor indices and distances."""
|
||||
logging.info('Writing neighbors')
|
||||
indices_path = Path(FLAGS.data_root) / data_utils.NEIGHBOR_INDICES_FILENAME
|
||||
distances_path = (
|
||||
Path(FLAGS.data_root) / data_utils.NEIGHBOR_DISTANCES_FILENAME)
|
||||
indices_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
distances_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(indices_path, 'wb') as fid:
|
||||
np.save(fid, neighbor_indices)
|
||||
with open(distances_path, 'wb') as fid:
|
||||
np.save(fid, neighbor_distances)
|
||||
|
||||
|
||||
def _write_fused_edges(fused_paper_adjacency_matrix):
|
||||
"""Write fused edges."""
|
||||
data_root = Path(FLAGS.data_root)
|
||||
edges_path = data_root / data_utils.FUSED_PAPER_EDGES_FILENAME
|
||||
edges_t_path = data_root / data_utils.FUSED_PAPER_EDGES_T_FILENAME
|
||||
edges_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
edges_t_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(edges_path, 'wb') as fid:
|
||||
sp.save_npz(fid, fused_paper_adjacency_matrix)
|
||||
with open(edges_t_path, 'wb') as fid:
|
||||
sp.save_npz(fid, fused_paper_adjacency_matrix.T)
|
||||
|
||||
|
||||
def _write_fused_nodes(fused_node_labels):
|
||||
"""Write fused nodes."""
|
||||
labels_path = Path(FLAGS.data_root) / data_utils.FUSED_NODE_LABELS_FILENAME
|
||||
labels_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(labels_path, 'wb') as fid:
|
||||
np.save(fid, fused_node_labels)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
paper_pca_features = _read_paper_pca_features()
|
||||
# Find neighbors.
|
||||
annoy_index = build_annoy_index(paper_pca_features)
|
||||
save_annoy_index(annoy_index)
|
||||
neighbor_indices, neighbor_distances = compute_neighbor_indices_and_distances(
|
||||
paper_pca_features)
|
||||
del paper_pca_features
|
||||
_write_neighbors(neighbor_indices, neighbor_distances)
|
||||
|
||||
data = _read_adjacency_indices()
|
||||
paper_paper_csr = data['paper_paper_index']
|
||||
paper_label = data['paper_label']
|
||||
train_indices = data['train_indices']
|
||||
valid_indices = data['valid_indices']
|
||||
test_indices = data['test_indices']
|
||||
del data
|
||||
|
||||
fused_paper_adjacency_matrix = data_utils.generate_fused_paper_adjacency_matrix(
|
||||
neighbor_indices, neighbor_distances, paper_paper_csr)
|
||||
_write_fused_edges(fused_paper_adjacency_matrix)
|
||||
del fused_paper_adjacency_matrix
|
||||
del paper_paper_csr
|
||||
|
||||
fused_node_labels = data_utils.generate_fused_node_labels(
|
||||
neighbor_indices, neighbor_distances, paper_label, train_indices,
|
||||
valid_indices, test_indices)
|
||||
_write_fused_nodes(fused_node_labels)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('data_root')
|
||||
app.run(main)
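As a reference point for what the Annoy index approximates, the sketch below computes exact Euclidean nearest neighbours by brute force on a toy array. It is feasible only for small inputs and is not how the script above works. The function name `brute_force_neighbors` and the toy data are assumptions. Each point comes back as its own nearest neighbour at distance 0, which may be why the script requests `k + pad_k + 1` neighbours per item.

```python
import numpy as np


def brute_force_neighbors(features, num_neighbors):
  """Returns exact nearest-neighbour indices and Euclidean distances."""
  # Pairwise squared distances, shape [num_vectors, num_vectors].
  sq_dists = np.sum(
      np.square(features[:, None, :] - features[None, :, :]), axis=-1)
  indices = np.argsort(sq_dists, axis=1, kind='stable')[:, :num_neighbors]
  distances = np.sqrt(np.take_along_axis(sq_dists, indices, axis=1))
  return indices.astype(np.int32), distances.astype(np.float32)


toy_features = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 3.0], [10.0, 10.0]])
neighbor_indices, neighbor_distances = brute_force_neighbors(toy_features, 3)
```

The pairwise distance matrix grows quadratically with the number of vectors, which is why an approximate index is used at the scale of this dataset.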
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
while getopts ":i:o:" opt; do
|
||||
case ${opt} in
|
||||
i )
|
||||
INPUT_DIR=$OPTARG
|
||||
;;
|
||||
o )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: organize_data.sh -i <Downloaded data dir> -o <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
if [[ -z "${INPUT_DIR}" ]]; then
|
||||
echo "Need INPUT_DIR argument (-i <INPUT_DIR>)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [[ -z "${TASK_ROOT}" ]]; then
|
||||
echo "Need TASK_ROOT argument (-o <TASK_ROOT>)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATA_ROOT="${TASK_ROOT}"/data
|
||||
|
||||
# Create raw directory to move all files to it.
|
||||
mkdir -p "${DATA_ROOT}"/raw
|
||||
|
||||
mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_feat.npy \
|
||||
"${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_label.npy \
|
||||
"${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_year.npy \
|
||||
"${DATA_ROOT}"/raw
|
||||
mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/author___affiliated_with___institution/edge_index.npy \
|
||||
"${DATA_ROOT}"/raw/author_affiliated_with_institution_edges.npy
|
||||
mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/author___writes___paper/edge_index.npy \
|
||||
"${DATA_ROOT}"/raw/author_writes_paper_edges.npy
|
||||
mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper___cites___paper/edge_index.npy \
|
||||
"${DATA_ROOT}"/raw/paper_cites_paper_edges.npy
|
||||
|
||||
# Split and save the train/valid/test indices to the raw directory, with names
|
||||
# "train_idx.npy", "valid_idx.npy", "test_idx.npy":
|
||||
python3 "${SCRIPT_DIR}"/split_and_save_indices.py --data_root="${DATA_ROOT}"
|
||||
@@ -0,0 +1,166 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Apply PCA to the papers' BERT features.
|
||||
|
||||
Compute papers' PCA features.
|
||||
Recompute author and institution features from the paper PCA features.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import data_utils
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
_NUMBER_OF_PAPERS_TO_ESTIMATE_PCA_ON = 1000000 # None indicates all.
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string('data_root', None, 'Data root directory')
|
||||
|
||||
|
||||
def _sample_vectors(vectors, num_samples, seed=0):
|
||||
"""Randomly sample some vectors."""
|
||||
rand = np.random.RandomState(seed=seed)
|
||||
indices = rand.choice(vectors.shape[0], size=num_samples, replace=False)
|
||||
return vectors[indices]
|
||||
|
||||
|
||||
def _pca(feat):
|
||||
"""Returns evals (variances), evecs (rows are principal components)."""
|
||||
cov = np.cov(feat.T)
|
||||
_, evals, evecs = np.linalg.svd(cov, full_matrices=True)
|
||||
return evals, evecs
|
||||
|
||||
|
||||
def _read_raw_paper_features():
|
||||
"""Load raw paper features."""
|
||||
path = Path(FLAGS.data_root) / data_utils.RAW_NODE_FEATURES_FILENAME
|
||||
try: # Use mmap if possible.
|
||||
features = np.load(path, mmap_mode='r')
|
||||
except FileNotFoundError:
|
||||
with open(path, 'rb') as fid:
|
||||
features = np.load(fid)
|
||||
return features
|
||||
|
||||
|
||||
def _get_principal_components(features,
|
||||
num_principal_components=129,
|
||||
num_samples=10000,
|
||||
seed=2,
|
||||
dtype='f4'):
|
||||
"""Estimate PCA features."""
|
||||
sample = _sample_vectors(
|
||||
features[:_NUMBER_OF_PAPERS_TO_ESTIMATE_PCA_ON], num_samples, seed=seed)
|
||||
# Compute PCA basis.
|
||||
_, evecs = _pca(sample)
|
||||
return evecs[:num_principal_components].T.astype(dtype)
|
||||
|
||||
|
||||
def _project_features_onto_principal_components(features,
|
||||
principal_components,
|
||||
block_size=1000000):
|
||||
"""Apply PCA iteratively."""
|
||||
num_principal_components = principal_components.shape[1]
|
||||
dtype = principal_components.dtype
|
||||
num_vectors = features.shape[0]
|
||||
  num_blocks = (num_vectors - 1) // block_size + 1
|
||||
pca_features = np.empty([num_vectors, num_principal_components], dtype=dtype)
|
||||
# Loop through in blocks.
|
||||
start_time = time.time()
|
||||
for i in range(num_blocks):
|
||||
i_start = i * block_size
|
||||
i_end = (i + 1) * block_size
|
||||
f = np.array(features[i_start:i_end].copy())
|
||||
pca_features[i_start:i_end] = np.dot(f, principal_components).astype(dtype)
|
||||
del f
|
||||
elapsed_time = time.time() - start_time
|
||||
time_left = elapsed_time / (i + 1) * (num_blocks - i - 1)
|
||||
logging.info('Features %d / %d. Elapsed time %.1f. Time left: %.1f', i_end,
|
||||
num_vectors, elapsed_time, time_left)
|
||||
return pca_features
|
||||
|
||||
|
||||
def _read_adjacency_indices():
|
||||
# Get adjacencies.
|
||||
return data_utils.get_arrays(
|
||||
data_root=FLAGS.data_root,
|
||||
use_fused_node_labels=False,
|
||||
use_fused_node_adjacencies=False,
|
||||
return_pca_embeddings=False,
|
||||
)
|
||||
|
||||
|
||||
def _compute_author_pca_features(paper_pca_features, index_arrays):
|
||||
return data_utils.paper_features_to_author_features(
|
||||
index_arrays['author_paper_index'], paper_pca_features)
|
||||
|
||||
|
||||
def _compute_institution_pca_features(author_pca_features, index_arrays):
|
||||
return data_utils.author_features_to_institution_features(
|
||||
index_arrays['institution_author_index'], author_pca_features)
|
||||
|
||||
|
||||
def _write_array(path, array):
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, 'wb') as fid:
|
||||
np.save(fid, array)
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
data_root = Path(FLAGS.data_root)
|
||||
|
||||
raw_paper_features = _read_raw_paper_features()
|
||||
principal_components = _get_principal_components(raw_paper_features)
|
||||
paper_pca_features = _project_features_onto_principal_components(
|
||||
raw_paper_features, principal_components)
|
||||
del raw_paper_features
|
||||
del principal_components
|
||||
|
||||
paper_pca_path = data_root / data_utils.PCA_PAPER_FEATURES_FILENAME
|
||||
author_pca_path = data_root / data_utils.PCA_AUTHOR_FEATURES_FILENAME
|
||||
institution_pca_path = (
|
||||
data_root / data_utils.PCA_INSTITUTION_FEATURES_FILENAME)
|
||||
merged_pca_path = data_root / data_utils.PCA_MERGED_FEATURES_FILENAME
|
||||
_write_array(paper_pca_path, paper_pca_features)
|
||||
|
||||
# Compute author and institution features from paper PCA features.
|
||||
index_arrays = _read_adjacency_indices()
|
||||
author_pca_features = _compute_author_pca_features(paper_pca_features,
|
||||
index_arrays)
|
||||
_write_array(author_pca_path, author_pca_features)
|
||||
|
||||
institution_pca_features = _compute_institution_pca_features(
|
||||
author_pca_features, index_arrays)
|
||||
_write_array(institution_pca_path, institution_pca_features)
|
||||
|
||||
merged_pca_features = np.concatenate(
|
||||
[paper_pca_features, author_pca_features, institution_pca_features],
|
||||
axis=0)
|
||||
del author_pca_features
|
||||
del institution_pca_features
|
||||
_write_array(merged_pca_path, merged_pca_features)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('data_root')
|
||||
app.run(main)
|
||||
@@ -0,0 +1,23 @@
|
||||
wheel
|
||||
absl-py>=0.12.0
|
||||
chex>=0.0.7
|
||||
dill>=0.3.3
|
||||
dm-haiku>=0.0.4
|
||||
# jaxline
|
||||
git+https://github.com/deepmind/jaxline@927695a0b88e480d085590736e2daa36c2f2c68b
|
||||
optax>=0.0.8
|
||||
ml_collections
|
||||
numpy>=1.16.4
|
||||
tensorflow>=2.5.0
|
||||
tensorflow-datasets>=4.3.0
|
||||
dm-tree>=0.1.6
|
||||
ogb>=1.3.1
|
||||
# graph_nets
|
||||
git+https://github.com/deepmind/graph_nets.git@64771dff0d74ca8e77b1f1dcd5a7d26634356d61
|
||||
scipy>=1.2.1
|
||||
jaxlib>=0.1.57+cuda110
|
||||
tqdm>=4.28.1
|
||||
annoy>=1.0.3
|
||||
# jraph
|
||||
git+https://github.com/deepmind/jraph.git@80d2f9e4d82f841a71d56ba44fa9e8781e93ae1f
|
||||
google-cloud-storage
|
||||
Executable
+57
@@ -0,0 +1,57 @@
|
||||
#!/bin/bash
# Copyright 2021 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
|
||||
while getopts ":r:" opt; do
|
||||
case ${opt} in
|
||||
r )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: preprocess_data.sh -r <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
|
||||
DATA_ROOT="${TASK_ROOT}"/data
|
||||
PREPROCESSED_DIR="${DATA_ROOT}"/preprocessed
|
||||
|
||||
# Create preprocessed directory to move all files to it.
|
||||
mkdir -p "${PREPROCESSED_DIR}"
|
||||
|
||||
# Run the CSR edge builder.
|
||||
python "${SCRIPT_DIR}"/csr_builder.py --data_root="${DATA_ROOT}"
|
||||
|
||||
# Run the PCA feature builder.
|
||||
python "${SCRIPT_DIR}"/pca_builder.py --data_root="${DATA_ROOT}"
|
||||
|
||||
# Run the neighbor-finder/fuser builder.
|
||||
python "${SCRIPT_DIR}"/neighbor_builder.py --data_root="${DATA_ROOT}"
|
||||
|
||||
# Run the validation split generator.
|
||||
python "${SCRIPT_DIR}"/generate_validation_splits.py \
|
||||
--data_root="${DATA_ROOT}" \
|
||||
--output_dir="${DATA_ROOT}/k_fold_splits"
|
||||
Executable
+107
@@ -0,0 +1,107 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
while getopts ":r:" opt; do
|
||||
case ${opt} in
|
||||
r )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: run_pretrained_eval.sh -r <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
echo "
|
||||
Note this script may take several days to run with default parameters on
|
||||
a single machine. Reducing EPOCHS_TO_ENSEMBLE from 50 to < 5 should yield
|
||||
slightly lower validation performance but possibly similar test performance.
|
||||
"
|
||||
|
||||
echo "
|
||||
Pre-requisites (See README):
|
||||
* Python dependencies have been installed.
|
||||
* pre-processed data is available in the task dir.
|
||||
* pre-trained model weights are available in the task dir.
|
||||
"
|
||||
|
||||
read -p "Press enter to continue"
|
||||
|
||||
|
||||
# Can set this to "valid" or "test".
|
||||
# On test the results will be ensembled for 10 models, and a submission file
|
||||
# will be created.
|
||||
# On validation, the results will be "gathered" for 10 models. This is because
|
||||
# each model is trained on the train split + 90% of the validation split,
|
||||
# and only evaluated on the remaining 10%, such that each validation paper
|
||||
# is left out from training in exactly one of the 10 models.
|
||||
SPLIT="test"
|
||||
|
||||
# We used 50 epochs in the submission to sample different subgraphs around
|
||||
# each central node.
|
||||
# For each of the 10 models in the ensemble, a single epoch takes about
# 10-25 minutes (depending on the GPU) on the test set, and about 2-3 minutes
|
||||
# for the corresponding k-fold split of the validation set.
|
||||
EPOCHS_TO_ENSEMBLE=50
|
||||
|
||||
DATA_ROOT=${TASK_ROOT}/data/
|
||||
MODELS_ROOT=${TASK_ROOT}/models/
|
||||
CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/
|
||||
OUTPUT_DIR=${TASK_ROOT}/predictions/
|
||||
|
||||
# We run two seeds for each model of the k=10 k-fold
|
||||
# first seed group: [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]
|
||||
# second seed group: [110, 111, 112, 113, 114, 115, 116, 117, 118, 119]
|
||||
# These are the seeds that were selected based on cross validation for each fold.
|
||||
BEST_SEEDS=(100 111 102 113 104 105 106 107 108 109)
|
||||
for K_FOLD_INDEX in {0..9}; do
|
||||
SEED=${BEST_SEEDS[${K_FOLD_INDEX}]}
|
||||
RESTORE_PATH=${MODELS_ROOT}/k${K_FOLD_INDEX}_seed${SEED}
|
||||
echo "Running k=${K_FOLD_INDEX} on ${SPLIT} split using ${RESTORE_PATH}"
|
||||
  # This saves the predictions for the K_FOLD_INDEX-th model in the k-fold to
|
||||
# "config.experiment_kwargs.config.predictions_dir" for subsequent ensembling
|
||||
python "${SCRIPT_DIR}"/experiment.py \
|
||||
--jaxline_mode="eval" \
|
||||
--config="${SCRIPT_DIR}"/config.py \
|
||||
--config.one_off_evaluate=True \
|
||||
--config.checkpoint_dir=${CHECKPOINT_DIR}/${K_FOLD_INDEX} \
|
||||
--config.restore_path=${RESTORE_PATH} \
|
||||
--config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \
|
||||
--config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \
|
||||
--config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \
|
||||
--config.experiment_kwargs.config.predictions_dir=${OUTPUT_DIR}/${K_FOLD_INDEX} \
|
||||
--config.experiment_kwargs.config.eval.split=${SPLIT}
|
||||
|
||||
done
|
||||
|
||||
|
||||
python "${SCRIPT_DIR}"/ensemble_predictions.py \
|
||||
--split=${SPLIT} \
|
||||
--data_root=${DATA_ROOT} \
|
||||
--predictions_path=${OUTPUT_DIR} \
|
||||
--output_path=${OUTPUT_DIR}
|
||||
|
||||
|
||||
echo "Done"
|
||||
Executable
+115
@@ -0,0 +1,115 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
while getopts ":r:" opt; do
|
||||
case ${opt} in
|
||||
r )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: run_training.sh -r <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
echo "
|
||||
This script is provided for illustrative purposes. It is not practical for
|
||||
actual training since it only uses a single machine, and likely requires
|
||||
reducing the batch size and/or model size to fit on a single GPU.
|
||||
|
||||
For the actual submission we used training distributed in different ways:
|
||||
* We used 4x Cloud TPU v4s to implement batch parallelism, so batch
|
||||
size is effectively 8 times larger than the values in the config.
|
||||
* We used separate CPUs to subsample neighborhoods around each central node, to
|
||||
maximize TPU usage by always having a batch ready for the next iteration.
|
||||
* We ran the online early-stopping evaluator on a separate machine with an
|
||||
NVIDIA V100 GPU.
|
||||
* We run identical replicas of all of the above for each of the 20 models
|
||||
trained.
|
||||
|
||||
Using those mechanisms, training runs at ~2 steps per second, reaching 500k
|
||||
steps in under 3 days.
|
||||
"
|
||||
|
||||
echo "
|
||||
Pre-requisites (See README):
|
||||
* Python dependencies have been installed.
|
||||
* pre-processed data is available in the task dir.
|
||||
"
|
||||
|
||||
read -p "Press enter to continue"
|
||||
|
||||
# During early stopping and seed selection we ensembled 5 iterations.
|
||||
EPOCHS_TO_ENSEMBLE=5
|
||||
|
||||
DATA_ROOT=${TASK_ROOT}/data/
|
||||
CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/
|
||||
|
||||
# We run two seeds for each model of the k=10 k-fold.
|
||||
BASE_SEED=100
|
||||
for SEED_OFFSET in 0 10; do
|
||||
for K_FOLD_INDEX in {0..9}; do
|
||||
    MODEL_SEED=$((BASE_SEED + SEED_OFFSET + K_FOLD_INDEX))
|
||||
SUFFIX=k${K_FOLD_INDEX}_seed${MODEL_SEED}
|
||||
echo "Running k=${K_FOLD_INDEX} with init seed ${MODEL_SEED}"
|
||||
|
||||
# This runs training (each model is trained on train split + 90% of the
|
||||
# validation split) with early stopping, storing both "latest" model and
|
||||
# "best" early-stopped model at `--config.checkpoint_dir`.
|
||||
|
||||
    # Models are early stopped based on accuracy on the 10% of the validation
    # data left out from training (each K_FOLD_INDEX leaves a different 10% of
    # the data out).
|
||||
|
||||
# Models are stored at the end of training. Intermediate models can also be
|
||||
# stored while training by sending a SIGINT signal (Ctrl+C) which will not
|
||||
# interrupt the training.
|
||||
|
||||
# It is possible to interrupt training using (Ctrl+\) and then continue
|
||||
# passing the corresponding `--config.restore_path=${RESTORE_PATH}` which is
|
||||
# stored when interrupting.
|
||||
|
||||
python "${SCRIPT_DIR}"/experiment.py \
|
||||
--jaxline_mode="train_eval_multithreaded" \
|
||||
--config="${SCRIPT_DIR}"/config.py \
|
||||
--config.random_seed=${MODEL_SEED} \
|
||||
--config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \
|
||||
--config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \
|
||||
--config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \
|
||||
--config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \
|
||||
--config.experiment_kwargs.config.eval.split="valid"
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
# Each of the 20 (two for each value of k) jobs generates paths of the form:
|
||||
# RESTORE_PATH=${--config.checkpoint_dir}/models/best/step_${STEP}_${TIMESTAMP}
|
||||
# From each pair, the best one is selected (based on the validation accuracy
|
||||
# as reported on the logs or in tensorboard events also stored at
|
||||
# `${--config.checkpoint_dir}/eval`). These can then be used as the
|
||||
# "RESTORE_PATHS" for `./run_pretrained_eval.sh`.
|
||||
|
||||
|
||||
echo "Done"
|
||||
@@ -0,0 +1,74 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Scheduling utilities."""
|
||||
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
def apply_ema_decay(
|
||||
ema_value: jnp.ndarray,
|
||||
current_value: jnp.ndarray,
|
||||
decay: jnp.ndarray,
|
||||
) -> jnp.ndarray:
|
||||
"""Implements EMA."""
|
||||
return ema_value * decay + current_value * (1 - decay)
|
||||
|
||||
|
||||
def ema_decay_schedule(
|
||||
base_rate: jnp.ndarray,
|
||||
step: jnp.ndarray,
|
||||
total_steps: jnp.ndarray,
|
||||
use_schedule: bool,
|
||||
) -> jnp.ndarray:
|
||||
"""Anneals decay rate to 1 with cosine schedule."""
|
||||
if not use_schedule:
|
||||
return base_rate
|
||||
multiplier = _cosine_decay(step, total_steps, 1.)
|
||||
return 1. - (1. - base_rate) * multiplier
|
||||
|
||||
|
||||
def _cosine_decay(
|
||||
global_step: jnp.ndarray,
|
||||
max_steps: int,
|
||||
initial_value: float,
|
||||
) -> jnp.ndarray:
|
||||
"""Simple implementation of cosine decay from TF1."""
|
||||
global_step = jnp.minimum(global_step, max_steps).astype(jnp.float32)
|
||||
cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps))
|
||||
decayed_learning_rate = initial_value * cosine_decay_value
|
||||
return decayed_learning_rate
|
||||
|
||||
|
||||
def learning_schedule(
|
||||
global_step: jnp.ndarray,
|
||||
base_learning_rate: float,
|
||||
total_steps: int,
|
||||
warmup_steps: int,
|
||||
use_schedule: bool,
|
||||
) -> float:
|
||||
"""Cosine learning rate scheduler."""
|
||||
# Compute LR & Scaled LR
|
||||
if not use_schedule:
|
||||
return base_learning_rate
|
||||
warmup_learning_rate = (
|
||||
global_step.astype(jnp.float32) / int(warmup_steps) *
|
||||
base_learning_rate if warmup_steps > 0 else base_learning_rate)
|
||||
|
||||
# Cosine schedule after warmup.
|
||||
decay_learning_rate = _cosine_decay(global_step - warmup_steps,
|
||||
total_steps - warmup_steps,
|
||||
base_learning_rate)
|
||||
return jnp.where(global_step < warmup_steps, warmup_learning_rate,
|
||||
decay_learning_rate)
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Split and save the train/valid/test indices.
|
||||
|
||||
Usage:
|
||||
|
||||
python3 split_and_save_indices.py --data_root="mag_data"
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string('data_root', None, 'Data root directory')
|
||||
|
||||
|
||||
def main(argv) -> None:
|
||||
if len(argv) > 1:
|
||||
raise app.UsageError('Too many command-line arguments.')
|
||||
mag_directory = Path(FLAGS.data_root) / 'mag240m_kddcup2021'
|
||||
raw_directory = mag_directory / 'raw'
|
||||
  raw_directory.mkdir(parents=True, exist_ok=True)
|
||||
splits_dict = torch.load(str(mag_directory / 'split_dict.pt'))
|
||||
for key, indices in splits_dict.items():
|
||||
np.save(str(raw_directory / f'{key}_idx.npy'), indices)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
  flags.mark_flag_as_required('data_root')
|
||||
app.run(main)
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for subsampling the MAG dataset."""
|
||||
|
||||
import collections
|
||||
|
||||
import jraph
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_or_sample_row(node_id: int,
|
||||
nb_neighbours: int,
|
||||
csr_matrix, remove_duplicates: bool):
|
||||
"""Either obtain entire row or a subsampled set of neighbours."""
|
||||
if node_id + 1 >= csr_matrix.indptr.shape[0]:
|
||||
lo = 0
|
||||
hi = 0
|
||||
else:
|
||||
lo = csr_matrix.indptr[node_id]
|
||||
hi = csr_matrix.indptr[node_id + 1]
|
||||
if lo == hi: # Skip empty neighbourhoods
|
||||
neighbours = None
|
||||
elif hi - lo <= nb_neighbours:
|
||||
neighbours = csr_matrix.indices[lo:hi]
|
||||
elif hi - lo < 5 * nb_neighbours: # For small surroundings, sample directly
|
||||
nb_neighbours = min(nb_neighbours, hi - lo)
|
||||
inds = lo + np.random.choice(hi - lo, size=(nb_neighbours,), replace=False)
|
||||
neighbours = csr_matrix.indices[inds]
|
||||
else: # Otherwise, do not slice -- sample indices instead
|
||||
# To extend GraphSAGE ("uniform w/ replacement"), modify this call
|
||||
inds = np.random.randint(lo, hi, size=(nb_neighbours,))
|
||||
if remove_duplicates:
|
||||
inds = np.unique(inds)
|
||||
neighbours = csr_matrix.indices[inds]
|
||||
return neighbours
|
||||
|
||||
|
||||
def get_neighbours(node_id: int,
|
||||
node_type: int,
|
||||
neighbour_type: int,
|
||||
nb_neighbours: int,
|
||||
remove_duplicates: bool,
|
||||
author_institution_csr, institution_author_csr,
|
||||
author_paper_csr, paper_author_csr,
|
||||
paper_paper_csr, paper_paper_transpose_csr):
|
||||
"""Fetch the edge indices from one node to corresponding neighbour type."""
|
||||
if node_type == 0 and neighbour_type == 0:
|
||||
csr = paper_paper_transpose_csr # Citing
|
||||
elif node_type == 0 and neighbour_type == 1:
|
||||
csr = paper_author_csr
|
||||
elif node_type == 0 and neighbour_type == 3:
|
||||
csr = paper_paper_csr # Cited
|
||||
elif node_type == 1 and neighbour_type == 0:
|
||||
csr = author_paper_csr
|
||||
elif node_type == 1 and neighbour_type == 2:
|
||||
csr = author_institution_csr
|
||||
elif node_type == 2 and neighbour_type == 1:
|
||||
csr = institution_author_csr
|
||||
else:
|
||||
raise ValueError('Non-existent edge type requested')
|
||||
return get_or_sample_row(node_id, nb_neighbours, csr, remove_duplicates)
|
||||
|
||||
|
||||
def get_senders(neighbour_type: int,
|
||||
sender_index,
|
||||
paper_features):
|
||||
"""Get the sender features from given neighbours."""
|
||||
if neighbour_type == 0 or neighbour_type == 3:
|
||||
sender_features = paper_features[sender_index]
|
||||
elif neighbour_type == 1 or neighbour_type == 2:
|
||||
sender_features = np.zeros((sender_index.shape[0],
|
||||
paper_features.shape[1])) # Consider averages
|
||||
else:
|
||||
raise ValueError('Non-existent node type requested')
|
||||
return sender_features
|
||||
|
||||
|
||||
def make_edge_type_feature(node_type: int, neighbour_type: int):
|
||||
edge_feats = np.zeros(7)
|
||||
edge_feats[node_type] = 1.0
|
||||
edge_feats[neighbour_type + 3] = 1.0
|
||||
return edge_feats
|
||||
|
||||
|
||||
def subsample_graph(paper_id: int,
|
||||
author_institution_csr,
|
||||
institution_author_csr,
|
||||
author_paper_csr,
|
||||
paper_author_csr,
|
||||
paper_paper_csr,
|
||||
paper_paper_transpose_csr,
|
||||
max_nb_neighbours_per_type,
|
||||
max_nodes=None,
|
||||
max_edges=None,
|
||||
paper_years=None,
|
||||
remove_future_nodes=False,
|
||||
deduplicate_nodes=False) -> jraph.GraphsTuple:
|
||||
"""Subsample a graph around given paper ID."""
|
||||
if paper_years is not None:
|
||||
root_paper_year = paper_years[paper_id]
|
||||
else:
|
||||
root_paper_year = None
|
||||
# Add the center node as "node-zero"
|
||||
sub_nodes = [paper_id]
|
||||
num_nodes_in_subgraph = 1
|
||||
num_edges_in_subgraph = 0
|
||||
reached_node_budget = False
|
||||
reached_edge_budget = False
|
||||
node_and_type_to_index_in_subgraph = dict()
|
||||
node_and_type_to_index_in_subgraph[(paper_id, 0)] = 0
|
||||
# Store all (integer) depths as an additional feature
|
||||
depths = [0]
|
||||
types = [0]
|
||||
sub_edges = []
|
||||
sub_senders = []
|
||||
sub_receivers = []
|
||||
|
||||
# Store all unprocessed neighbours
|
||||
# Each neighbour is stored as a 4-tuple (node_index in original graph,
|
||||
# node_index in subsampled graph, type, number of hops away from source).
|
||||
# TYPES: 0: paper, 1: author, 2: institution, 3: paper (for bidirectional)
|
||||
neighbour_deque = collections.deque([(paper_id, 0, 0, 0)])
|
||||
|
||||
max_depth = len(max_nb_neighbours_per_type)
|
||||
|
||||
while neighbour_deque and not reached_edge_budget:
|
||||
left_entry = neighbour_deque.popleft()
|
||||
node_index, node_index_in_sampled_graph, node_type, node_depth = left_entry
|
||||
|
||||
# Expand from this node, to a node of related type
|
||||
for neighbour_type in range(4):
|
||||
if reached_edge_budget:
|
||||
break # Budget may have been reached in previous type; break here.
|
||||
nb_neighbours = max_nb_neighbours_per_type[node_depth][node_type][neighbour_type] # pylint:disable=line-too-long
|
||||
# Only extend if we want to sample further in this edge type
|
||||
if nb_neighbours > 0:
|
||||
sampled_neighbors = get_neighbours(
|
||||
node_index,
|
||||
node_type,
|
||||
neighbour_type,
|
||||
nb_neighbours,
|
||||
deduplicate_nodes,
|
||||
author_institution_csr,
|
||||
institution_author_csr,
|
||||
author_paper_csr,
|
||||
paper_author_csr,
|
||||
paper_paper_csr,
|
||||
paper_paper_transpose_csr,
|
||||
)
|
||||
|
||||
if sampled_neighbors is not None:
|
||||
if remove_future_nodes and root_paper_year is not None:
|
||||
if neighbour_type in [0, 3]:
|
||||
sampled_neighbors = [
|
||||
x for x in sampled_neighbors
|
||||
if paper_years[x] <= root_paper_year
|
||||
]
|
||||
if not sampled_neighbors:
|
||||
continue
|
||||
|
||||
nb_neighbours = len(sampled_neighbors)
|
||||
edge_feature = make_edge_type_feature(node_type, neighbour_type)
|
||||
|
||||
for neighbor_original_idx in sampled_neighbors:
|
||||
# Key into dict of existing nodes using both node id and type.
|
||||
neighbor_key = (neighbor_original_idx, neighbour_type % 3)
|
||||
# Get existing idx in subgraph if it exists.
|
||||
neighbor_subgraph_idx = node_and_type_to_index_in_subgraph.get(
|
||||
neighbor_key, None)
|
||||
if (not reached_node_budget and
|
||||
(not deduplicate_nodes or neighbor_subgraph_idx is None)):
|
||||
# If it does not exist already, or we are not deduplicating,
|
||||
# just create a new node and update the dict.
|
||||
neighbor_subgraph_idx = num_nodes_in_subgraph
|
||||
node_and_type_to_index_in_subgraph[neighbor_key] = (
|
||||
neighbor_subgraph_idx)
|
||||
num_nodes_in_subgraph += 1
|
||||
sub_nodes.append(neighbor_original_idx)
|
||||
types.append(neighbour_type % 3)
|
||||
depths.append(node_depth + 1)
|
||||
if max_nodes is not None and num_nodes_in_subgraph >= max_nodes:
|
||||
reached_node_budget = True
|
||||
continue # Move to next neighbor which might already exist.
|
||||
if node_depth < max_depth - 1:
|
||||
# If the neighbours are to be further expanded, enqueue them.
|
||||
# Expand only if the nodes did not already exist.
|
||||
neighbour_deque.append(
|
||||
(neighbor_original_idx, neighbor_subgraph_idx,
|
||||
neighbour_type % 3, node_depth + 1))
|
||||
# The neighbor id within graph is now fixed; just add edges.
|
||||
if neighbor_subgraph_idx is not None:
|
||||
# Either node existed before or was successfully added.
|
||||
sub_senders.append(neighbor_subgraph_idx)
|
||||
sub_receivers.append(node_index_in_sampled_graph)
|
||||
sub_edges.append(edge_feature)
|
||||
num_edges_in_subgraph += 1
|
||||
if max_edges is not None and num_edges_in_subgraph >= max_edges:
|
||||
reached_edge_budget = True
|
||||
break # Break out of adding edges for this neighbor type
|
||||
|
||||
# Stitch the graph together
|
||||
sub_nodes = np.array(sub_nodes, dtype=np.int32)
|
||||
|
||||
if sub_senders:
|
||||
sub_senders = np.array(sub_senders, dtype=np.int32)
|
||||
sub_receivers = np.array(sub_receivers, dtype=np.int32)
|
||||
sub_edges = np.stack(sub_edges, axis=0)
|
||||
else:
|
||||
# Use empty arrays.
|
||||
sub_senders = np.zeros([0], dtype=np.int32)
|
||||
sub_receivers = np.zeros([0], dtype=np.int32)
|
||||
sub_edges = np.zeros([0, 7])
|
||||
|
||||
# Finally, derive the sizes
|
||||
sub_n_node = np.array([sub_nodes.shape[0]])
|
||||
sub_n_edge = np.array([sub_senders.shape[0]])
|
||||
assert sub_nodes.shape[0] == num_nodes_in_subgraph
|
||||
assert sub_edges.shape[0] == num_edges_in_subgraph
|
||||
if max_nodes is not None:
|
||||
assert num_nodes_in_subgraph <= max_nodes
|
||||
if max_edges is not None:
|
||||
assert num_edges_in_subgraph <= max_edges
|
||||
|
||||
types = np.array(types)
|
||||
depths = np.array(depths)
|
||||
sub_nodes = {
|
||||
'index': sub_nodes.astype(np.int32),
|
||||
'type': types.astype(np.int16),
|
||||
'depth': depths.astype(np.int16),
|
||||
}
|
||||
|
||||
return jraph.GraphsTuple(nodes=sub_nodes,
|
||||
edges=sub_edges.astype(np.float16),
|
||||
senders=sub_senders.astype(np.int32),
|
||||
receivers=sub_receivers.astype(np.int32),
|
||||
globals=np.array([0], dtype=np.int16),
|
||||
n_node=sub_n_node.astype(dtype=np.int32),
|
||||
n_edge=sub_n_edge.astype(dtype=np.int32))
|
||||
@@ -0,0 +1,110 @@
|
||||
# DeepMind entry for PCQM4M-LSC
|
||||
|
||||
This repository contains DeepMind's entry to the [PCQM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry)
|
||||
track of the [OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/)
|
||||
(OGB-LSC).
|
||||
|
||||
For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
|
||||
|
||||
## DeepMind PCQ Team ("Quantum")
|
||||
|
||||
(in alphabetical order)
|
||||
|
||||
- Ravichandra Addanki
|
||||
- Peter Battaglia
|
||||
- David Budden
|
||||
- Andreea Deac
|
||||
- Jonathan Godwin
|
||||
- Alvaro Sanchez-Gonzalez
|
||||
- Wai Lok Sibon Li
|
||||
- Jacklynn Stott
|
||||
- Shantanu Thakoor
|
||||
- Petar Veličković
|
||||
|
||||
|
||||
## Performance
|
||||
|
||||
Our final test set performance was achieved by pooling an ensemble of 20 models
|
||||
(10 folds x 2 seeds). See the [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf) for details.
|
||||
|
||||
Each model was trained for < 48 hours using 4x Google Cloud TPUv4 and 1x AMD
|
||||
EPYC 7B12 64-core CPU @2.25GHz.
|
||||
|
||||
Inference takes < 6 hours on 1x NVIDIA V100 16GB GPU and 1x Intel Xeon Gold 6148
|
||||
20-core CPU @2.40GHz.
|
||||
|
||||
# Running our model
|
||||
|
||||
## Setup
|
||||
|
||||
You can set up a Python virtual environment (you might need to install the
|
||||
`python3-venv` package first) with all needed dependencies inside the forked
|
||||
`deepmind_research` repository using:
|
||||
|
||||
```bash
|
||||
python3 -m venv /tmp/pcq_venv
|
||||
source /tmp/pcq_venv/bin/activate
|
||||
pip3 install --upgrade pip setuptools wheel
|
||||
pip3 install -r ogb_lsc/pcq/requirements.txt
|
||||
```
|
||||
|
||||
## Downloading data and model weights
|
||||
|
||||
All necessary data and pre-trained model weights can be downloaded by running
|
||||
the following command.
|
||||
This downloads about 150 GB of model checkpoints.
|
||||
|
||||
```bash
|
||||
python download_required_pcq_data.py --data_root=${HOME}/data/
|
||||
```
|
||||
|
||||
## Generating pre-processed features
|
||||
|
||||
All the additional features used in training
|
||||
(k-fold splits and conformer position features) can be generated by running the following command.
|
||||
```bash
|
||||
/bin/bash run_preprocessing.sh -r ${HOME}/data/pcq/
|
||||
```
|
||||
|
||||
## Reproducing our final results
|
||||
|
||||
We have provided pre-trained weights of our final submission for convenience. To
|
||||
reproduce our final results, please run `run_pretrained_eval.sh` as follows.
|
||||
|
||||
```bash
|
||||
/bin/bash run_pretrained_eval.sh -r ${HOME}/data/pcq/
|
||||
```
|
||||
|
||||
## Retraining our model
|
||||
|
||||
Disclaimer: This script is provided for illustrative purposes. It is not
|
||||
practical for actual training since it only uses a single machine, and likely
|
||||
requires reducing the batch size and/or model size to fit on a single GPU.
|
||||
|
||||
If you still want to train a model, please run `run_training.sh`. To simply
|
||||
validate that the code is running correctly on your hardware setup, consider
|
||||
setting `debug=True` in `config.py`, which trains a smaller model.
|
||||
|
||||
|
||||
```bash
|
||||
/bin/bash run_training.sh -r ${HOME}/data/pcq/
|
||||
```
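
For orientation, the batch budgets that `config.py` derives for normal and debug runs can be sketched in plain Python. The constants are copied from the config in this entry (a rule of thumb of 16 nodes and 40 edges per graph, and a batch size of 64), and the extra node and graph mirror the `+ 1` that the dataset code adds for padding. The helper names `dynamic_batch_budget` and `padded_budget` are hypothetical and exist only for this illustration.

```python
# Illustrative only: mirrors the numbers in config.py, not the real config API.
AVG_NODES_PER_GRAPH = 16  # Rule-of-thumb statistic quoted in config.py.
AVG_EDGES_PER_GRAPH = 40  # Rule-of-thumb statistic quoted in config.py.


def dynamic_batch_budget(batch_size: int, debug: bool = False) -> dict:
  """Returns the node/edge/graph budget for one dynamic batch."""
  if debug:
    # Debug mode uses small fixed budgets, independent of the batch size.
    return {"n_node": 256, "n_edge": 512, "n_graph": 2}
  return {
      "n_node": AVG_NODES_PER_GRAPH * batch_size,
      "n_edge": AVG_EDGES_PER_GRAPH * batch_size,
      "n_graph": batch_size,
  }


def padded_budget(budget: dict) -> dict:
  """Adds the one node and one graph that are reserved for padding."""
  return {
      "n_node": budget["n_node"] + 1,
      "n_edge": budget["n_edge"],
      "n_graph": budget["n_graph"] + 1,
  }


if __name__ == "__main__":
  print(padded_budget(dynamic_batch_budget(64)))
  print(padded_budget(dynamic_batch_budget(64, debug=True)))
```

If a single accelerator runs out of memory, lowering the batch size shrinks all three budgets proportionally under this scheme.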
|
||||
|
||||
|
||||
# Citation
|
||||
|
||||
To cite this work (together with our MAG240M-LSC entry):
|
||||
|
||||
```latex
|
||||
@article{deepmind2021ogb,
|
||||
author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea
|
||||
Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro
|
||||
Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar
|
||||
Veli\v{c}kovi\'{c}},
|
||||
title = {Large-scale graph representation learning with very deep GNNs and
|
||||
self-supervision},
|
||||
year = {2021},
|
||||
}
|
||||
```
|
||||
|
||||
Our technical report can be found [here](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf).
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dynamic batching utilities."""
|
||||
|
||||
from typing import Generator, Iterator, Sequence, Tuple
|
||||
|
||||
import jax.tree_util as tree
|
||||
import jraph
|
||||
import numpy as np
|
||||
|
||||
_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph")
|
||||
|
||||
|
||||
def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple],
|
||||
n_node: int, n_edge: int,
|
||||
n_graph: int) -> Generator[jraph.GraphsTuple, None, None]:
|
||||
"""Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`.
|
||||
|
||||
Elements of the `graphs_tuple_iterator` will be incrementally added to a batch
|
||||
until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This
|
||||
means each element yielded by this generator may contain a different number of graphs.
|
||||
|
||||
For situations where you have variable sized data, it's useful to be able to
|
||||
have variable sized batches. This is especially the case if you have a loss
|
||||
defined on the variable shaped element (for example, nodes in a graph).
|
||||
|
||||
Args:
|
||||
graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`.
|
||||
n_node: The maximum number of nodes in a batch.
|
||||
n_edge: The maximum number of edges in a batch.
|
||||
n_graph: The maximum number of graphs in a batch.
|
||||
|
||||
Yields:
|
||||
A `jraph.GraphsTuple` batch of graphs.
|
||||
|
||||
Raises:
|
||||
ValueError: if the number of graphs is < 2.
|
||||
RuntimeError: if the `graphs_tuple_iterator` contains elements which are not
|
||||
`jraph.GraphsTuple`s.
|
||||
RuntimeError: if a graph is found which is larger than the batch size.
|
||||
"""
|
||||
if n_graph < 2:
|
||||
raise ValueError("The number of graphs in a batch must be greater than or "
|
||||
f"equal to `2` for padding with graphs, got {n_graph}.")
|
||||
valid_batch_size = (n_node - 1, n_edge, n_graph - 1)
|
||||
accumulated_graphs = []
|
||||
num_accumulated_nodes = 0
|
||||
num_accumulated_edges = 0
|
||||
num_accumulated_graphs = 0
|
||||
for element in graphs_tuple_iterator:
|
||||
element_nodes, element_edges, element_graphs = _get_graph_size(element)
|
||||
if _is_over_batch_size(element, valid_batch_size):
|
||||
graph_size = element_nodes, element_edges, element_graphs
|
||||
graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)}
|
||||
batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)}
|
||||
raise RuntimeError("Found graph bigger than batch size. Valid Batch "
|
||||
f"Size: {batch_size}, Graph Size: {graph_size}")
|
||||
|
||||
if not accumulated_graphs:
|
||||
# If this is the first element of the batch, set it and continue.
|
||||
accumulated_graphs = [element]
|
||||
num_accumulated_nodes = element_nodes
|
||||
num_accumulated_edges = element_edges
|
||||
num_accumulated_graphs = element_graphs
|
||||
continue
|
||||
else:
|
||||
# Otherwise check if there is space for the graph in the batch:
|
||||
if ((num_accumulated_graphs + element_graphs > n_graph - 1) or
|
||||
(num_accumulated_nodes + element_nodes > n_node - 1) or
|
||||
(num_accumulated_edges + element_edges > n_edge)):
|
||||
# If there is not, yield the current batch and start a new one.
|
||||
batched_graph = _batch_np(accumulated_graphs)
|
||||
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
|
||||
accumulated_graphs = [element]
|
||||
num_accumulated_nodes = element_nodes
|
||||
num_accumulated_edges = element_edges
|
||||
num_accumulated_graphs = element_graphs
|
||||
else:
|
||||
# Otherwise, add the graph to the current batch.
|
||||
accumulated_graphs.append(element)
|
||||
num_accumulated_nodes += element_nodes
|
||||
num_accumulated_edges += element_edges
|
||||
num_accumulated_graphs += element_graphs
|
||||
|
||||
# We may still have data in batched graph.
|
||||
if accumulated_graphs:
|
||||
batched_graph = _batch_np(accumulated_graphs)
|
||||
yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph)
|
||||
|
||||
|
||||
def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple:
|
||||
# Calculates offsets for sender and receiver arrays, caused by concatenating
|
||||
# the nodes arrays.
|
||||
offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]]))
|
||||
|
||||
def _map_concat(nests):
|
||||
concat = lambda *args: np.concatenate(args)
|
||||
return tree.tree_map(concat, *nests)
|
||||
|
||||
return jraph.GraphsTuple(
|
||||
n_node=np.concatenate([g.n_node for g in graphs]),
|
||||
n_edge=np.concatenate([g.n_edge for g in graphs]),
|
||||
nodes=_map_concat([g.nodes for g in graphs]),
|
||||
edges=_map_concat([g.edges for g in graphs]),
|
||||
globals=_map_concat([g.globals for g in graphs]),
|
||||
senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]),
|
||||
receivers=np.concatenate(
|
||||
[g.receivers + o for g, o in zip(graphs, offsets)]))
|
||||
|
||||
|
||||
def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]:
|
||||
n_node = np.sum(graph.n_node)
|
||||
n_edge = len(graph.senders)
|
||||
n_graph = len(graph.n_node)
|
||||
return n_node, n_edge, n_graph
|
||||
|
||||
|
||||
def _is_over_batch_size(
|
||||
graph: jraph.GraphsTuple,
|
||||
graph_batch_size: Tuple[int, int, int],
|
||||
) -> bool:
|
||||
graph_size = _get_graph_size(graph)
|
||||
return any([x > y for x, y in zip(graph_size, graph_batch_size)])
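
The accumulation logic in `dynamically_batch` can be illustrated without `jraph` by tracking only graph sizes. The sketch below is a simplified stand-in rather than the function above: `greedy_batches` is a hypothetical name, each input is a `(num_nodes, num_edges)` pair for a single graph, and padding is only accounted for by reserving one node and one graph per batch.

```python
from typing import Iterable, Iterator, List, Tuple

GraphSize = Tuple[int, int]  # (num_nodes, num_edges) of a single graph.


def greedy_batches(sizes: Iterable[GraphSize], n_node: int, n_edge: int,
                   n_graph: int) -> Iterator[List[GraphSize]]:
  """Greedily groups graph sizes into batches that fit a fixed budget."""
  if n_graph < 2:
    raise ValueError(
        f"n_graph must be >= 2 to leave room for padding, got {n_graph}.")
  # One node and one graph are reserved for the padding graph.
  max_nodes, max_edges, max_graphs = n_node - 1, n_edge, n_graph - 1
  batch: List[GraphSize] = []
  nodes = edges = 0
  for num_nodes, num_edges in sizes:
    if num_nodes > max_nodes or num_edges > max_edges:
      raise RuntimeError("Found graph bigger than batch size.")
    does_not_fit = (len(batch) + 1 > max_graphs or
                    nodes + num_nodes > max_nodes or
                    edges + num_edges > max_edges)
    if batch and does_not_fit:
      # No space left: emit the current batch and start a new one.
      yield batch
      batch, nodes, edges = [], 0, 0
    batch.append((num_nodes, num_edges))
    nodes += num_nodes
    edges += num_edges
  if batch:
    # Flush the final, possibly smaller, batch.
    yield batch


if __name__ == "__main__":
  example = [(3, 2), (4, 3), (5, 4), (2, 1)]
  for b in greedy_batches(example, n_node=9, n_edge=10, n_graph=4):
    print(b)
```

Because batches are cut as soon as the next graph would overflow any one of the three limits, the number of graphs per batch varies with graph size, which is why the real function pads every batch to the same fixed shape.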
|
||||
@@ -0,0 +1,130 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Experiment config for PCQM4M-LSC entry."""
|
||||
|
||||
from jaxline import base_config
|
||||
from ml_collections import config_dict
|
||||
|
||||
|
||||
def get_config(debug: bool = False) -> config_dict.ConfigDict:
|
||||
"""Get Jaxline experiment config."""
|
||||
config = base_config.get_base_config()
|
||||
# E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below)
|
||||
config.restore_path = config_dict.placeholder(str)
|
||||
|
||||
training_batch_size = 64
|
||||
eval_batch_size = 64
|
||||
|
||||
## Experiment config.
|
||||
loss_config_name = 'RegressionLossConfig'
|
||||
loss_kwargs = dict(
|
||||
exponent=1., # 2 for l2 loss, 1 for l1 loss, etc...
|
||||
)
|
||||
|
||||
dataset_config = dict(
|
||||
data_root=config_dict.placeholder(str),
|
||||
augment_with_random_mirror_symmetry=True,
|
||||
k_fold_split_id=config_dict.placeholder(int),
|
||||
num_k_fold_splits=config_dict.placeholder(int),
|
||||
# Options: "in" or "out".
|
||||
# Filter=in would keep the samples with NaNs in the conformer features.
|
||||
# Filter=out would keep the samples with no NaNs anywhere in the conformer
|
||||
# features.
|
||||
filter_in_or_out_samples_with_nans_in_conformers=(
|
||||
config_dict.placeholder(str)),
|
||||
cached_conformers_file=config_dict.placeholder(str))
|
||||
|
||||
model_config = dict(
|
||||
mlp_hidden_size=512,
|
||||
mlp_layers=2,
|
||||
latent_size=512,
|
||||
use_layer_norm=False,
|
||||
num_message_passing_steps=32,
|
||||
shared_message_passing_weights=False,
|
||||
mask_padding_graph_at_every_step=True,
|
||||
loss_config_name=loss_config_name,
|
||||
loss_kwargs=loss_kwargs,
|
||||
processor_mode='resnet',
|
||||
global_reducer='sum',
|
||||
node_reducer='sum',
|
||||
dropedge_rate=0.1,
|
||||
dropnode_rate=0.1,
|
||||
aux_multiplier=0.1,
|
||||
add_relative_distance=True,
|
||||
add_relative_displacement=True,
|
||||
add_absolute_positions=False,
|
||||
position_normalization=2.,
|
||||
relative_displacement_normalization=1.,
|
||||
ignore_globals=False,
|
||||
ignore_globals_from_final_layer_for_predictions=True,
|
||||
)
|
||||
|
||||
if debug:
|
||||
# Make network smaller.
|
||||
model_config.update(dict(
|
||||
mlp_hidden_size=32,
|
||||
mlp_layers=1,
|
||||
latent_size=32,
|
||||
num_message_passing_steps=1))
|
||||
|
||||
config.experiment_kwargs = config_dict.ConfigDict(
|
||||
dict(
|
||||
config=dict(
|
||||
debug=debug,
|
||||
predictions_dir=config_dict.placeholder(str),
|
||||
ema=True,
|
||||
ema_decay=0.9999,
|
||||
sample_random=0.05,
|
||||
optimizer=dict(
|
||||
name='adam',
|
||||
optimizer_kwargs=dict(b1=.9, b2=.95),
|
||||
lr_schedule=dict(
|
||||
warmup_steps=int(5e4),
|
||||
decay_steps=int(5e5),
|
||||
init_value=1e-5,
|
||||
peak_value=1e-4,
|
||||
end_value=0.,
|
||||
),
|
||||
),
|
||||
model=model_config,
|
||||
dataset_config=dataset_config,
|
||||
# As a rule of thumb, use the following statistics:
|
||||
# Avg. # nodes in graph: 16.
|
||||
# Avg. # edges in graph: 40.
|
||||
training=dict(
|
||||
dynamic_batch_size={
|
||||
'n_node': 256 if debug else 16 * training_batch_size,
|
||||
'n_edge': 512 if debug else 40 * training_batch_size,
|
||||
'n_graph': 2 if debug else training_batch_size,
|
||||
},),
|
||||
evaluation=dict(
|
||||
split='valid',
|
||||
dynamic_batch_size=dict(
|
||||
n_node=256 if debug else 16 * eval_batch_size,
|
||||
n_edge=512 if debug else 40 * eval_batch_size,
|
||||
n_graph=2 if debug else eval_batch_size,
|
||||
)))))
|
||||
|
||||
## Training loop config.
|
||||
config.training_steps = int(5e6)
|
||||
config.checkpoint_dir = '/tmp/checkpoint/pcq/'
|
||||
config.train_checkpoint_all_hosts = False
|
||||
config.save_checkpoint_interval = 300
|
||||
config.log_train_data_interval = 60
|
||||
config.log_tensors_interval = 60
|
||||
config.best_model_eval_metric = 'mae'
|
||||
config.best_model_eval_metric_higher_is_better = False
|
||||
|
||||
return config
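
The `lr_schedule` values in the config above (`init_value`, `peak_value`, `warmup_steps`, `decay_steps`, `end_value`) match the argument names of a linear-warmup, cosine-decay schedule such as Optax's `warmup_cosine_decay_schedule`. The config itself does not state the shape of the schedule, so the following pure-Python sketch assumes that form, with `decay_steps` counted from step 0 so that the cosine phase spans `decay_steps - warmup_steps` steps. `warmup_cosine_lr` is a hypothetical helper, not part of this code base.

```python
import math


def warmup_cosine_lr(step: int,
                     init_value: float = 1e-5,
                     peak_value: float = 1e-4,
                     warmup_steps: int = 50_000,
                     decay_steps: int = 500_000,
                     end_value: float = 0.0) -> float:
  """Learning rate at `step` under an ASSUMED warmup + cosine decay shape."""
  if step < warmup_steps:
    # Linear ramp from init_value to peak_value.
    fraction = step / warmup_steps
    return init_value + fraction * (peak_value - init_value)
  # Cosine decay from peak_value to end_value, then constant at end_value.
  progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
  cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
  return end_value + (peak_value - end_value) * cosine


if __name__ == "__main__":
  for s in (0, 25_000, 50_000, 275_000, 500_000):
    print(s, warmup_cosine_lr(s))
```

Under this assumption the learning rate reaches `end_value` at `decay_steps` and stays there for any remaining training steps.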
|
||||
@@ -0,0 +1,324 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Conformer utilities."""
|
||||
|
||||
import copy
|
||||
from typing import List, Optional
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import rdkit
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import AllChem
|
||||
import tensorflow.compat.v2 as tf
|
||||
|
||||
|
||||
def generate_conformers(
|
||||
molecule: Chem.rdchem.Mol,
|
||||
max_num_conformers: int,
|
||||
*,
|
||||
random_seed: int = -1,
|
||||
prune_rms_thresh: float = -1.0,
|
||||
max_iter: int = -1,
|
||||
fallback_to_random: bool = False,
|
||||
) -> Chem.rdchem.Mol:
|
||||
"""Generates conformers for a given molecule.
|
||||
|
||||
Args:
|
||||
molecule: molecular representation of the compound.
|
||||
max_num_conformers: maximum number of conformers to generate. If pruning is
|
||||
done, the returned number of conformers is not guaranteed to match
|
||||
max_num_conformers.
|
||||
random_seed: random seed to use for conformer generation.
|
||||
prune_rms_thresh: RMSD threshold used to prune conformers that are
|
||||
too similar.
|
||||
max_iter: Maximum number of iterations to perform when optimising MMFF force
|
||||
field. If set to <= 0, energy optimisation is not performed.
|
||||
fallback_to_random: if conformers cannot be obtained, use random coordinates
|
||||
to initialise.
|
||||
|
||||
Returns:
|
||||
Copy of `molecule`; hydrogens are added for embedding and removed again before returning. The returned molecule contains
|
||||
force field-optimised conformers. The number of conformers is guaranteed to
|
||||
be <= max_num_conformers.
|
||||
"""
|
||||
mol = copy.deepcopy(molecule)
|
||||
mol = Chem.AddHs(mol)
|
||||
mol = _embed_conformers(
|
||||
mol,
|
||||
max_num_conformers,
|
||||
random_seed,
|
||||
prune_rms_thresh,
|
||||
fallback_to_random,
|
||||
use_random=False)
|
||||
|
||||
if max_iter > 0:
|
||||
mol_with_conformers = _minimize_by_mmff(mol, max_iter)
|
||||
if mol_with_conformers is None:
|
||||
mol_with_conformers = _minimize_by_uff(mol, max_iter)
|
||||
else:
|
||||
mol_with_conformers = mol
|
||||
# Aligns conformations in a molecule to each other using the first
|
||||
# conformation as the reference.
|
||||
AllChem.AlignMolConformers(mol_with_conformers)
|
||||
|
||||
# We remove hydrogens to keep the number of atoms consistent with the graph
|
||||
# nodes.
|
||||
mol_with_conformers = Chem.RemoveHs(mol_with_conformers)
|
||||
|
||||
return mol_with_conformers
|
||||
|
||||
|
||||
def atom_to_feature_vector(
|
||||
atom: rdkit.Chem.rdchem.Atom,
|
||||
conformer: Optional[Chem.rdchem.Conformer] = None,
|
||||
) -> List[float]:
|
||||
"""Converts rdkit atom object to feature list of indices.
|
||||
|
||||
Args:
|
||||
atom: rdkit atom object.
|
||||
conformer: Generated conformer. NaN positions are returned if set to None.
|
||||
|
||||
Returns:
|
||||
List containing the position (x, y, z) of the atom in the conformer.
|
||||
"""
|
||||
if conformer is not None:
|
||||
pos = conformer.GetAtomPosition(atom.GetIdx())
|
||||
return [pos.x, pos.y, pos.z]
|
||||
return [np.nan, np.nan, np.nan]
|
||||
|
||||
|
||||
def compute_conformer(smile: str, max_iter: int = -1) -> np.ndarray:
|
||||
"""Computes conformer.
|
||||
|
||||
Args:
|
||||
smile: SMILES string.
|
||||
max_iter: Maximum number of iterations to perform when optimising MMFF force
|
||||
field. If set to <= 0, energy optimisation is not performed.
|
||||
|
||||
Returns:
|
||||
Array of shape [num_atoms, 3] with the (x, y, z) position of each atom (NaN if conformer generation failed).
|
||||
Raises:
|
||||
RuntimeError: If unable to convert smile string to RDKit mol.
|
||||
"""
|
||||
mol = rdkit.Chem.MolFromSmiles(smile)
|
||||
if not mol:
|
||||
raise RuntimeError('Unable to convert smile to molecule: %s' % smile)
|
||||
conformer_failed = False
|
||||
try:
|
||||
mol = generate_conformers(
|
||||
mol,
|
||||
max_num_conformers=1,
|
||||
random_seed=45,
|
||||
prune_rms_thresh=0.01,
|
||||
max_iter=max_iter)
|
||||
except IOError as e:
|
||||
logging.exception('Failed to generate conformers for %s . IOError %s.',
|
||||
smile, e)
|
||||
conformer_failed = True
|
||||
except ValueError:
|
||||
logging.error('Failed to generate conformers for %s . ValueError', smile)
|
||||
conformer_failed = True
|
||||
except: # pylint: disable=bare-except
|
||||
logging.error('Failed to generate conformers for %s.', smile)
|
||||
conformer_failed = True
|
||||
|
||||
atom_features_list = []
|
||||
conformer = None if conformer_failed else list(mol.GetConformers())[0]
|
||||
for atom in mol.GetAtoms():
|
||||
atom_features_list.append(atom_to_feature_vector(atom, conformer))
|
||||
conformer_features = np.array(atom_features_list, dtype=np.float32)
|
||||
return conformer_features
|
||||
|
||||
|
||||
def get_random_rotation_matrix(include_mirror_symmetry: bool) -> tf.Tensor:
|
||||
"""Returns a single random rotation matrix."""
|
||||
rotation_matrix = _get_random_rotation_3d()
|
||||
if include_mirror_symmetry:
|
||||
random_mirror_symmetry = _get_random_mirror_symmetry()
|
||||
rotation_matrix = tf.matmul(rotation_matrix, random_mirror_symmetry)
|
||||
|
||||
return rotation_matrix
|
||||
|
||||
|
||||
def rotate(vectors: tf.Tensor, rotation_matrix: tf.Tensor) -> tf.Tensor:
|
||||
"""Batch of vectors on a single rotation matrix."""
|
||||
return tf.matmul(vectors, rotation_matrix)
|
||||
|
||||
|
||||
def _embed_conformers(
|
||||
molecule: Chem.rdchem.Mol,
|
||||
max_num_conformers: int,
|
||||
random_seed: int,
|
||||
prune_rms_thresh: float,
|
||||
fallback_to_random: bool,
|
||||
*,
|
||||
use_random: bool = False,
|
||||
) -> Chem.rdchem.Mol:
|
||||
"""Embeds conformers into a copy of a molecule.
|
||||
|
||||
If random coordinates allowed, tries not to use random coordinates at first,
|
||||
and uses random only if fails.
|
||||
|
||||
Args:
|
||||
molecule: molecular representation of the compound.
|
||||
max_num_conformers: maximum number of conformers to generate. If pruning is
|
||||
done, the returned number of conformers is not guaranteed to match
|
||||
max_num_conformers.
|
||||
random_seed: random seed to use for conformer generation.
|
||||
prune_rms_thresh: RMSD threshold used to prune conformers that are
|
||||
too similar.
|
||||
fallback_to_random: if conformers cannot be obtained, use random coordinates
|
||||
to initialise.
|
||||
*:
|
||||
use_random: Use random coordinates. Shouldn't be set by any caller except
|
||||
this function itself.
|
||||
|
||||
Returns:
|
||||
A copy of a molecule with embedded conformers.
|
||||
|
||||
Raises:
|
||||
ValueError: if conformers cannot be obtained for a given molecule.
|
||||
"""
|
||||
mol = copy.deepcopy(molecule)
|
||||
|
||||
# Obtains parameters for conformer generation.
|
||||
# In particular, ETKDG is experimental-torsion basic knowledge distance
|
||||
# geometry, which randomly generates an initial conformation that
|
||||
# satisfies various geometric constraints such as lower and upper bounds on
|
||||
# the distances between atoms.
|
||||
params = AllChem.ETKDGv3()
|
||||
|
||||
params.randomSeed = random_seed
|
||||
params.pruneRmsThresh = prune_rms_thresh
|
||||
params.numThreads = -1
|
||||
params.useRandomCoords = use_random
|
||||
|
||||
conf_ids = AllChem.EmbedMultipleConfs(mol, max_num_conformers, params)
|
||||
|
||||
if not conf_ids:
|
||||
if not fallback_to_random or use_random:
|
||||
raise ValueError('Cannot get conformers')
|
||||
return _embed_conformers(
|
||||
mol,
|
||||
max_num_conformers,
|
||||
random_seed,
|
||||
prune_rms_thresh,
|
||||
fallback_to_random,
|
||||
use_random=True)
|
||||
return mol
|
||||
|
||||
|
||||
def _minimize_by_mmff(
|
||||
molecule: Chem.rdchem.Mol,
|
||||
max_iter: int,
|
||||
) -> Optional[Chem.rdchem.Mol]:
|
||||
"""Minimizes forcefield for conformers using MMFF algorithm.
|
||||
|
||||
Args:
|
||||
molecule: a datastructure containing conformers.
|
||||
max_iter: number of maximum iterations to use when optimising force field.
|
||||
|
||||
Returns:
|
||||
A copy of a `molecule` containing optimised conformers; or None if MMFF
|
||||
cannot be performed.
|
||||
"""
|
||||
molecule_props = AllChem.MMFFGetMoleculeProperties(molecule)
|
||||
if molecule_props is None:
|
||||
return None
|
||||
|
||||
mol = copy.deepcopy(molecule)
|
||||
for conf_id in range(mol.GetNumConformers()):
|
||||
ff = AllChem.MMFFGetMoleculeForceField(
|
||||
mol, molecule_props, confId=conf_id, ignoreInterfragInteractions=False)
|
||||
ff.Initialize()
|
||||
# minimises a conformer within a mol in place.
|
||||
ff.Minimize(max_iter)
|
||||
return mol
|
||||
|
||||
|
||||
def _minimize_by_uff(
|
||||
molecule: Chem.rdchem.Mol,
|
||||
max_iter: int,
|
||||
) -> Chem.rdchem.Mol:
|
||||
"""Minimizes forcefield for conformers using UFF algorithm.
|
||||
|
||||
Args:
|
||||
molecule: a datastructure containing conformers.
|
||||
max_iter: number of maximum iterations to use when optimising force field.
|
||||
|
||||
Returns:
|
||||
A copy of a `molecule` containing optimised conformers.
|
||||
"""
|
||||
mol = copy.deepcopy(molecule)
|
||||
conf_ids = range(mol.GetNumConformers())
|
||||
for conf_id in conf_ids:
|
||||
ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id)
|
||||
ff.Initialize()
|
||||
# minimises a conformer within a mol in place.
|
||||
ff.Minimize(max_iter)
|
||||
return mol
|
||||
|
||||
|
||||
def _get_symmetry_rotation_matrix(sign: tf.Tensor) -> tf.Tensor:
|
||||
"""Returns the 2d/3d matrix for mirror symmetry."""
|
||||
zero = tf.zeros_like(sign)
|
||||
one = tf.ones_like(sign)
|
||||
# pylint: disable=bad-whitespace,bad-continuation
|
||||
rot = [sign, zero, zero,
|
||||
zero, one, zero,
|
||||
zero, zero, one]
|
||||
# pylint: enable=bad-whitespace,bad-continuation
|
||||
shape = (3, 3)
|
||||
rot = tf.stack(rot, axis=-1)
|
||||
rot = tf.reshape(rot, shape)
|
||||
return rot
|
||||
|
||||
|
||||
def _quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor:
|
||||
"""Converts a batch of quaternions to a batch of rotation matrices."""
|
||||
q0 = quaternion[0]
|
||||
q1 = quaternion[1]
|
||||
q2 = quaternion[2]
|
||||
q3 = quaternion[3]
|
||||
|
||||
r00 = 2 * (q0 * q0 + q1 * q1) - 1
|
||||
r01 = 2 * (q1 * q2 - q0 * q3)
|
||||
r02 = 2 * (q1 * q3 + q0 * q2)
|
||||
r10 = 2 * (q1 * q2 + q0 * q3)
|
||||
r11 = 2 * (q0 * q0 + q2 * q2) - 1
|
||||
r12 = 2 * (q2 * q3 - q0 * q1)
|
||||
r20 = 2 * (q1 * q3 - q0 * q2)
|
||||
r21 = 2 * (q2 * q3 + q0 * q1)
|
||||
r22 = 2 * (q0 * q0 + q3 * q3) - 1
|
||||
|
||||
matrix = tf.stack([r00, r01, r02,
|
||||
r10, r11, r12,
|
||||
r20, r21, r22], axis=-1)
|
||||
return tf.reshape(matrix, [3, 3])
|
||||
|
||||
|
||||
def _get_random_rotation_3d() -> tf.Tensor:
|
||||
random_quaternions = tf.random.normal(
|
||||
shape=[4], dtype=tf.float32)
|
||||
random_quaternions /= tf.linalg.norm(
|
||||
random_quaternions, axis=-1, keepdims=True)
|
||||
return _quaternion_to_rotation_matrix(random_quaternions)
|
||||
|
||||
|
||||
def _get_random_mirror_symmetry() -> tf.Tensor:
|
||||
random_0_1 = tf.random.uniform(
|
||||
shape=(), minval=0, maxval=2, dtype=tf.int32)
|
||||
random_signs = tf.cast((2 * random_0_1) - 1, tf.float32)
|
||||
return _get_symmetry_rotation_matrix(random_signs)
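
The random rotation used for augmentation is built by normalising a Gaussian 4-vector into a unit quaternion and converting it to a matrix, optionally followed by a reflection of the x axis. A dependency-free sketch of the same formulas follows. It uses Python lists instead of TensorFlow tensors, and all function names are hypothetical. It can be used to check that the result is orthonormal with determinant +1, or -1 once the mirror is applied.

```python
import math
import random


def quaternion_to_rotation_matrix(q):
  """Converts a single unit quaternion (q0, q1, q2, q3) to a 3x3 matrix."""
  q0, q1, q2, q3 = q
  return [
      [2 * (q0 * q0 + q1 * q1) - 1, 2 * (q1 * q2 - q0 * q3),
       2 * (q1 * q3 + q0 * q2)],
      [2 * (q1 * q2 + q0 * q3), 2 * (q0 * q0 + q2 * q2) - 1,
       2 * (q2 * q3 - q0 * q1)],
      [2 * (q1 * q3 - q0 * q2), 2 * (q2 * q3 + q0 * q1),
       2 * (q0 * q0 + q3 * q3) - 1],
  ]


def random_rotation(rng):
  """Samples a rotation by normalising a Gaussian 4-vector."""
  q = [rng.gauss(0.0, 1.0) for _ in range(4)]
  norm = math.sqrt(sum(x * x for x in q))
  return quaternion_to_rotation_matrix([x / norm for x in q])


def apply_mirror(matrix, sign):
  """Right-multiplies by diag(sign, 1, 1), i.e. scales the first column."""
  return [[sign * row[0], row[1], row[2]] for row in matrix]


def determinant(m):
  """Determinant of a 3x3 matrix given as nested lists."""
  return (m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1]) -
          m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0]) +
          m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]))


if __name__ == "__main__":
  rotation = random_rotation(random.Random(0))
  print(determinant(rotation))
  print(determinant(apply_mirror(rotation, -1.0)))
```

Applying the mirror flips the sign of the determinant, so the augmentation can also produce the mirror image of a conformer when `include_mirror_symmetry` is enabled.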
|
||||
@@ -0,0 +1,391 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Dataset utilities."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import jax
|
||||
import jraph
|
||||
from ml_collections import config_dict
|
||||
import numpy as np
|
||||
from ogb import utils
|
||||
from ogb.utils import features
|
||||
import tensorflow.compat.v2 as tf
|
||||
import tensorflow_datasets as tfds
|
||||
import tree
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
# pytype: disable=import-error
|
||||
import batching_utils
|
||||
import conformer_utils
|
||||
import datasets
|
||||
|
||||
|
||||
def build_dataset_iterator(
|
||||
data_root: str,
|
||||
split: str,
|
||||
dynamic_batch_size_config: config_dict.ConfigDict,
|
||||
sample_random: float,
|
||||
cached_conformers_file: str,
|
||||
debug: bool = False,
|
||||
is_training: bool = True,
|
||||
augment_with_random_mirror_symmetry: bool = False,
|
||||
positions_noise_std: Optional[float] = None,
|
||||
k_fold_split_id: Optional[int] = None,
|
||||
num_k_fold_splits: Optional[int] = None,
|
||||
filter_in_or_out_samples_with_nans_in_conformers: Optional[str] = None,
|
||||
):
|
||||
"""Returns an iterator over Batches from the dataset."""
|
||||
if debug:
|
||||
max_items_to_read_from_dataset = 10
|
||||
prefetch_buffer_size = 1
|
||||
shuffle_buffer_size = 1
|
||||
else:
|
||||
max_items_to_read_from_dataset = -1 # < 0 means no limit.
|
||||
prefetch_buffer_size = 64
|
||||
# It can take a while to fill the shuffle buffer with k fold splits.
|
||||
shuffle_buffer_size = 128 if k_fold_split_id is None else int(1e6)
|
||||
|
||||
num_local_devices = jax.local_device_count()
|
||||
|
||||
# Load all smile strings.
|
||||
indices, smiles, labels = _load_smiles(
|
||||
data_root,
|
||||
split,
|
||||
k_fold_split_id=k_fold_split_id,
|
||||
num_k_fold_splits=num_k_fold_splits)
|
||||
if debug:
|
||||
indices = indices[:100]
|
||||
smiles = smiles[:100]
|
||||
labels = labels[:100]
|
||||
# Generate all conformer features from smile strings ahead of time.
|
||||
# This gives us a boost from multi-parallelism as opposed to doing it
|
||||
# online.
|
||||
conformers = _load_conformers(indices, smiles, cached_conformers_file)
|
||||
|
||||
data_generator = (
|
||||
lambda: _get_pcq_graph_generator(indices, smiles, labels, conformers))
|
||||
# Create a dataset yielding graphs from smile strings.
|
||||
example = next(data_generator())
|
||||
signature_from_example = tree.map_structure(_numpy_to_tensor_spec, example)
|
||||
ds = tf.data.Dataset.from_generator(
|
||||
data_generator, output_signature=signature_from_example)
|
||||
|
||||
ds = ds.take(max_items_to_read_from_dataset)
|
||||
ds = ds.cache()
|
||||
if is_training:
|
||||
ds = ds.shuffle(shuffle_buffer_size)
|
||||
|
||||
# Apply transformations.
|
||||
def map_fn(graph, conformer_positions):
|
||||
graph = _maybe_one_hot_atoms_with_noise(
|
||||
graph, is_training=is_training, sample_random=sample_random)
|
||||
# Add conformer features.
|
||||
graph = _add_conformer_features(
|
||||
graph,
|
||||
conformer_positions,
|
||||
augment_with_random_mirror_symmetry=augment_with_random_mirror_symmetry,
|
||||
noise_std=positions_noise_std,
|
||||
is_training=is_training,
|
||||
)
|
||||
return _downcast_ints(graph)
|
||||
|
||||
ds = ds.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE)
|
||||
if filter_in_or_out_samples_with_nans_in_conformers:
|
||||
if filter_in_or_out_samples_with_nans_in_conformers not in ("in", "out"):
|
||||
raise ValueError(
|
||||
"Unknown value specified for the argument "
|
||||
"`filter_in_or_out_samples_with_nans_in_conformers`: %s" %
|
||||
filter_in_or_out_samples_with_nans_in_conformers)
|
||||
|
||||
filter_fn = _get_conformer_filter(
|
||||
with_nans=(filter_in_or_out_samples_with_nans_in_conformers == "in"))
|
||||
ds = ds.filter(filter_fn)
|
||||
|
||||
if is_training:
|
||||
ds = ds.shard(jax.process_count(), jax.process_index())
|
||||
ds = ds.repeat()
|
||||
|
||||
ds = ds.prefetch(prefetch_buffer_size)
|
||||
it = tfds.as_numpy(ds)
|
||||
|
||||
# Dynamic batching.
|
||||
batched_gen = batching_utils.dynamically_batch(
|
||||
it,
|
||||
n_node=dynamic_batch_size_config.n_node + 1,
|
||||
n_edge=dynamic_batch_size_config.n_edge,
|
||||
n_graph=dynamic_batch_size_config.n_graph + 1,
|
||||
)
|
||||
|
||||
if is_training:
|
||||
# Stack `num_local_devices` of batches together for pmap updates.
|
||||
batch_size = num_local_devices
|
||||
|
||||
def _batch(l):
|
||||
assert l
|
||||
return tree.map_structure(lambda *l: np.stack(l, axis=0), *l)
|
||||
|
||||
def batcher_fn():
|
||||
batch = []
|
||||
for sample in batched_gen:
|
||||
batch.append(sample)
|
||||
if len(batch) == batch_size:
|
||||
yield _batch(batch)
|
||||
batch = []
|
||||
if batch:
|
||||
yield _batch(batch)
|
||||
|
||||
for sample in batcher_fn():
|
||||
yield sample
|
||||
else:
|
||||
for sample in batched_gen:
|
||||
yield sample
|
||||
|
||||
|
||||
def _get_conformer_filter(with_nans: bool):
|
||||
"""Selects a conformer filter to apply.
|
||||
|
||||
Args:
|
||||
with_nans: If True, the filter selects only samples with NaNs in conformer features.
|
||||
Else, selects samples without any NaNs in conformer features.
|
||||
|
||||
Returns:
|
||||
A function that can be used with tf.data.Dataset.filter().
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If the input graph to the filter has no conformer features to filter.
|
||||
"""
|
||||
|
||||
def _filter(graph: jraph.GraphsTuple) -> tf.Tensor:
|
||||
|
||||
if ("positions" not in graph.nodes) or (
|
||||
"positions_targets" not in graph.nodes) or (
|
||||
"positions_nan_mask" not in graph.globals):
|
||||
raise ValueError("Conformer features not available to filter.")
|
||||
|
||||
any_nan = tf.logical_not(tf.squeeze(graph.globals["positions_nan_mask"]))
|
||||
return any_nan if with_nans else tf.logical_not(any_nan)
|
||||
|
||||
return _filter
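The filter factory above returns a predicate that keeps either the NaN-containing samples or the NaN-free ones. A simplified sketch of the same idea on plain Python lists, assuming positions are given as rows of floats; `make_nan_filter` is a hypothetical name, and the original operates on a precomputed `positions_nan_mask` tensor rather than scanning positions:

```python
import math
from typing import Callable, Sequence

Positions = Sequence[Sequence[float]]


def make_nan_filter(with_nans: bool) -> Callable[[Positions], bool]:
  """Returns a predicate selecting samples with (or without) NaN positions."""

  def _filter(positions: Positions) -> bool:
    any_nan = any(math.isnan(v) for row in positions for v in row)
    return any_nan if with_nans else not any_nan

  return _filter


clean = [[0.0, 1.0, 2.0]]
broken = [[0.0, float("nan"), 2.0]]
keep_clean = make_nan_filter(with_nans=False)
keep_broken = make_nan_filter(with_nans=True)
```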
|
||||
|
||||
|
||||
def _numpy_to_tensor_spec(arr: np.ndarray) -> tf.TensorSpec:
|
||||
if not isinstance(arr, np.ndarray):
|
||||
return tf.TensorSpec([],
|
||||
dtype=tf.int32 if isinstance(arr, int) else tf.float32)
|
||||
elif arr.shape:
|
||||
return tf.TensorSpec((None,) + arr.shape[1:], arr.dtype)
|
||||
else:
|
||||
return tf.TensorSpec([], arr.dtype)
|
||||
|
||||
|
||||
def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor:
  return tf.random.categorical(tf.math.log([[1 / size] * size]), num)[0]
|
||||
|
||||
|
||||
@jax.curry(jax.tree_map)
def _downcast_ints(x):
  if x.dtype == tf.int64:
    return tf.cast(x, tf.int32)
  return x
|
||||
|
||||
|
||||
def _one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor:
|
||||
vocab_sizes = features.get_atom_feature_dims()
|
||||
one_hots = []
|
||||
for i in range(atoms.shape[1]):
|
||||
one_hots.append(tf.one_hot(atoms[:, i], vocab_sizes[i], dtype=tf.float32))
|
||||
return tf.concat(one_hots, axis=-1)
|
||||
|
||||
|
||||
def _sample_one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor:
|
||||
vocab_sizes = features.get_atom_feature_dims()
|
||||
one_hots = []
|
||||
num_atoms = tf.shape(atoms)[0]
|
||||
for i in range(atoms.shape[1]):
|
||||
sampled_category = _sample_uniform_categorical(num_atoms, vocab_sizes[i])
|
||||
one_hots.append(
|
||||
tf.one_hot(sampled_category, vocab_sizes[i], dtype=tf.float32))
|
||||
return tf.concat(one_hots, axis=-1)
|
||||
|
||||
|
||||
def _one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor:
|
||||
vocab_sizes = features.get_bond_feature_dims()
|
||||
one_hots = []
|
||||
for i in range(bonds.shape[1]):
|
||||
one_hots.append(tf.one_hot(bonds[:, i], vocab_sizes[i], dtype=tf.float32))
|
||||
return tf.concat(one_hots, axis=-1)
|
||||
|
||||
|
||||
def _sample_one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor:
|
||||
vocab_sizes = features.get_bond_feature_dims()
|
||||
one_hots = []
|
||||
num_bonds = tf.shape(bonds)[0]
|
||||
for i in range(bonds.shape[1]):
|
||||
sampled_category = _sample_uniform_categorical(num_bonds, vocab_sizes[i])
|
||||
one_hots.append(
|
||||
tf.one_hot(sampled_category, vocab_sizes[i], dtype=tf.float32))
|
||||
return tf.concat(one_hots, axis=-1)
|
||||
|
||||
|
||||
def _maybe_one_hot_atoms_with_noise(
|
||||
x,
|
||||
is_training: bool,
|
||||
sample_random: float,
|
||||
):
|
||||
"""One hot atoms with noise."""
|
||||
gt_nodes = _one_hot_atoms(x.nodes)
|
||||
gt_edges = _one_hot_bonds(x.edges)
|
||||
if is_training:
|
||||
num_nodes = tf.shape(x.nodes)[0]
|
||||
sample_node_or_not = tf.random.uniform([num_nodes],
|
||||
maxval=1) < sample_random
|
||||
nodes = tf.where(
|
||||
tf.expand_dims(sample_node_or_not, axis=-1),
|
||||
_sample_one_hot_atoms(x.nodes), gt_nodes)
|
||||
num_edges = tf.shape(x.edges)[0]
|
||||
sample_edges_or_not = tf.random.uniform([num_edges],
|
||||
maxval=1) < sample_random
|
||||
edges = tf.where(
|
||||
tf.expand_dims(sample_edges_or_not, axis=-1),
|
||||
_sample_one_hot_bonds(x.edges), gt_edges)
|
||||
else:
|
||||
nodes = gt_nodes
|
||||
edges = gt_edges
|
||||
return x._replace(
|
||||
nodes={
|
||||
"atom_one_hots_targets": gt_nodes,
|
||||
"atom_one_hots": nodes,
|
||||
},
|
||||
edges={
|
||||
"bond_one_hots_targets": gt_edges,
|
||||
"bond_one_hots": edges
|
||||
})
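During training, the function above replaces each node or edge one-hot with a uniformly sampled one-hot with probability `sample_random`, while keeping the uncorrupted one-hots as targets. A rough standard-library sketch of that corruption for a single categorical feature; `one_hot` and `noisy_one_hot` are hypothetical helpers, and the original concatenates one-hots over several feature columns, each with its own vocabulary size:

```python
import random
from typing import List, Sequence, Tuple


def one_hot(index: int, size: int) -> List[float]:
  return [1.0 if i == index else 0.0 for i in range(size)]


def noisy_one_hot(
    indices: Sequence[int],
    vocab_size: int,
    sample_random: float,
    rng: random.Random,
) -> Tuple[List[List[float]], List[List[float]]]:
  """Returns (possibly corrupted inputs, ground-truth targets)."""
  targets = [one_hot(i, vocab_size) for i in indices]
  inputs = []
  for target in targets:
    if rng.random() < sample_random:
      # Replace the row with a uniformly sampled category.
      inputs.append(one_hot(rng.randrange(vocab_size), vocab_size))
    else:
      inputs.append(target)
  return inputs, targets


inputs, targets = noisy_one_hot(
    [0, 2, 1], vocab_size=3, sample_random=0.0, rng=random.Random(0))
```

With `sample_random=0.0` the inputs equal the targets, which corresponds to the evaluation branch of the original function.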
|
||||
|
||||
|
||||
def _load_smiles(
|
||||
data_root: str,
|
||||
split: str,
|
||||
k_fold_split_id: int,
|
||||
num_k_fold_splits: int,
|
||||
):
|
||||
"""Loads smiles strings for the input split."""
|
||||
|
||||
if split == "test" or k_fold_split_id is None:
|
||||
indices = datasets.load_splits()[split]
|
||||
elif split == "train":
|
||||
indices = datasets.load_all_except_kth_fold_indices(
|
||||
data_root, k_fold_split_id, num_k_fold_splits)
|
||||
else:
|
||||
assert split == "valid"
|
||||
indices = datasets.load_kth_fold_indices(data_root, k_fold_split_id)
|
||||
|
||||
smiles_and_labels = datasets.load_smile_strings(with_labels=True)
|
||||
smiles, labels = list(zip(*smiles_and_labels))
|
||||
return indices, [smiles[i] for i in indices], [labels[i] for i in indices]
|
||||
|
||||
|
||||
def _convert_ogb_graph_to_graphs_tuple(ogb_graph):
|
||||
"""Converts an OGB Graph to a GraphsTuple."""
|
||||
senders = ogb_graph["edge_index"][0]
|
||||
receivers = ogb_graph["edge_index"][1]
|
||||
edges = ogb_graph["edge_feat"]
|
||||
nodes = ogb_graph["node_feat"]
|
||||
n_node = np.array([ogb_graph["num_nodes"]])
|
||||
n_edge = np.array([len(senders)])
|
||||
graph = jraph.GraphsTuple(
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
senders=senders,
|
||||
receivers=receivers,
|
||||
n_node=n_node,
|
||||
n_edge=n_edge,
|
||||
globals=None)
|
||||
return tree.map_structure(lambda x: x if x is not None else np.array(0.),
|
||||
graph)
|
||||
|
||||
|
||||
def _load_conformers(indices: List[int],
|
||||
smiles: List[str],
|
||||
cached_conformers_file: str):
|
||||
"""Loads conformers."""
|
||||
smile_to_conformer = datasets.load_cached_conformers(cached_conformers_file)
|
||||
conformers = []
|
||||
for graph_idx, smile in zip(indices, smiles):
|
||||
del graph_idx # Unused.
|
||||
if smile not in smile_to_conformer:
|
||||
raise KeyError("Cache did not have conformer entry for the smile %s" %
|
||||
str(smile))
|
||||
conformers.append(dict(conformer=smile_to_conformer[smile]))
|
||||
return conformers
|
||||
|
||||
|
||||
def _add_conformer_features(
|
||||
graph,
|
||||
conformer_features,
|
||||
augment_with_random_mirror_symmetry: bool,
|
||||
noise_std: float,
|
||||
is_training: bool,
|
||||
):
|
||||
"""Adds conformer features."""
|
||||
if not isinstance(graph.nodes, dict):
|
||||
raise ValueError("Expected a dict type for `graph.nodes`.")
|
||||
# Remove mean position to center around a canonical origin.
|
||||
positions = conformer_features["conformer"]
|
||||
# NaNs appear in ~0.13% of training, 0.104% of validation and 0.16% of test
|
||||
# nodes.
|
||||
# See this colab: http://shortn/_6UcuosxY7x.
|
||||
nan_mask = tf.reduce_any(tf.math.is_nan(positions))
|
||||
|
||||
positions = tf.where(nan_mask, tf.constant(0., positions.dtype), positions)
|
||||
positions -= tf.reduce_mean(positions, axis=0, keepdims=True)
|
||||
|
||||
# During training, apply a random rotation (optionally with mirroring).
|
||||
if is_training:
|
||||
rot_mat = conformer_utils.get_random_rotation_matrix(
|
||||
augment_with_random_mirror_symmetry)
|
||||
positions = conformer_utils.rotate(positions, rot_mat)
|
||||
positions_targets = positions
|
||||
|
||||
# Optionally add noise to the positions.
|
||||
if noise_std and is_training:
|
||||
positions = tf.random.normal(tf.shape(positions), positions, noise_std)
|
||||
|
||||
return graph._replace(
|
||||
nodes=dict(
|
||||
positions=positions,
|
||||
positions_targets=positions_targets,
|
||||
**graph.nodes),
|
||||
globals={
|
||||
"positions_nan_mask":
|
||||
tf.expand_dims(tf.logical_not(nan_mask), axis=0),
|
||||
**(graph.globals if isinstance(graph.globals, dict) else {})
|
||||
})
|
||||
|
||||
|
||||
def _get_pcq_graph_generator(indices, smiles, labels, conformers):
|
||||
"""Returns a generator to yield graph."""
|
||||
for idx, smile, conformer_positions, label in zip(indices, smiles, conformers,
|
||||
labels):
|
||||
graph = utils.smiles2graph(smile)
|
||||
graph = _convert_ogb_graph_to_graphs_tuple(graph)
|
||||
graph = graph._replace(
|
||||
globals={
|
||||
"target": np.array([label], dtype=np.float32),
|
||||
"graph_index": np.array([idx], dtype=np.int32),
|
||||
**(graph.globals if isinstance(graph.globals, dict) else {})
|
||||
})
|
||||
yield graph, conformer_positions
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""PCQM4M-LSC datasets."""
|
||||
|
||||
import functools
|
||||
import pickle
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from ogb import lsc
|
||||
|
||||
NUM_VALID_SAMPLES = 380_670
|
||||
NUM_TEST_SAMPLES = 377_423
|
||||
|
||||
NORMALIZE_TARGET_MEAN = 5.690944545356371
|
||||
NORMALIZE_TARGET_STD = 1.1561347795107815
|
||||
|
||||
|
||||
def load_splits() -> Dict[str, List[int]]:
|
||||
"""Loads dataset splits."""
|
||||
dataset = _get_pcq_dataset(only_smiles=True)
|
||||
return dataset.get_idx_split()
|
||||
|
||||
|
||||
def load_kth_fold_indices(data_root: str, k_fold_split_id: int) -> List[int]:
|
||||
"""Loads k-th fold indices."""
|
||||
fname = f"{data_root}/k_fold_splits/{k_fold_split_id}.pkl"
|
||||
return list(map(int, _load_pickle(fname)))
|
||||
|
||||
|
||||
def load_all_except_kth_fold_indices(data_root: str, k_fold_split_id: int,
|
||||
num_k_fold_splits: int) -> List[int]:
|
||||
"""Loads indices except for the kth fold."""
|
||||
if k_fold_split_id is None:
|
||||
raise ValueError("Expected integer value for `k_fold_split_id`.")
|
||||
indices = []
|
||||
for index in range(num_k_fold_splits):
|
||||
if index != k_fold_split_id:
|
||||
indices += load_kth_fold_indices(data_root, index)
|
||||
return indices
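The two loaders above implement a k-fold split: fold `k` is held out for validation and the remaining folds are concatenated for training. A small sketch of that logic with an in-memory dict in place of the pickled fold files; `all_except_kth_fold` and the example folds are hypothetical:

```python
from typing import Dict, List


def all_except_kth_fold(folds: Dict[int, List[int]], k: int) -> List[int]:
  """Concatenates the indices of every fold except fold `k`."""
  indices = []
  for fold_id in sorted(folds):
    if fold_id != k:
      indices += folds[fold_id]
  return indices


folds = {0: [0, 1], 1: [2, 3], 2: [4, 5]}
train_indices = all_except_kth_fold(folds, k=1)
valid_indices = folds[1]
```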
|
||||
|
||||
|
||||
def load_smile_strings(
|
||||
with_labels=False) -> List[Union[str, Tuple[str, np.ndarray]]]:
|
||||
"""Loads the smile strings in the PCQ dataset."""
|
||||
dataset = _get_pcq_dataset(only_smiles=True)
|
||||
smiles = []
|
||||
for i in range(len(dataset)):
|
||||
smile, label = dataset[i]
|
||||
if with_labels:
|
||||
smiles.append((smile, label))
|
||||
else:
|
||||
smiles.append(smile)
|
||||
|
||||
return smiles
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def load_cached_conformers(cached_fname: str) -> Dict[str, np.ndarray]:
|
||||
"""Returns cached dict mapping smile strings to conformer features."""
|
||||
return _load_pickle(cached_fname)
|
||||
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_pcq_dataset(only_smiles: bool):
|
||||
return lsc.PCQM4MDataset(only_smiles=only_smiles)
|
||||
|
||||
|
||||
def _load_pickle(fname: str):
|
||||
with open(fname, "rb") as f:
|
||||
return pickle.load(f)
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Download data required for training and evaluating models."""
|
||||
|
||||
import pathlib
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
from google.cloud import storage
|
||||
|
||||
Path = pathlib.Path
|
||||
|
||||
|
||||
_BUCKET_NAME = 'deepmind-ogb-lsc'
|
||||
_MAX_DOWNLOAD_ATTEMPTS = 5
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_enum('payload', None, ['data', 'models'],
|
||||
'Download "data" or "models"?')
|
||||
flags.DEFINE_string('task_root', None, 'Local task root directory')
|
||||
|
||||
# GCS_DATA_ROOT = Path('pcq/data')
|
||||
# GCS_MODEL_ROOT = Path('pcq/models')
|
||||
|
||||
DATA_RELATIVE_PATHS = [
|
||||
'raw/data.csv.gz', 'preprocessed/smile_to_conformer.pkl',
|
||||
'k_fold_splits'
|
||||
]
|
||||
|
||||
|
||||
class DataCorruptionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _get_gcs_root():
|
||||
return Path('pcq') / FLAGS.payload
|
||||
|
||||
|
||||
def _get_gcs_bucket():
|
||||
storage_client = storage.Client.create_anonymous_client()
|
||||
return storage_client.bucket(_BUCKET_NAME)
|
||||
|
||||
|
||||
def _write_blob_to_destination(blob, task_root, ignore_existing=True):
|
||||
"""Write the blob."""
|
||||
logging.info("Copying blob: '%s'", blob.name)
|
||||
destination_path = Path(task_root) / Path(*Path(blob.name).parts[1:])
|
||||
logging.info(" ... to: '%s'", str(destination_path))
|
||||
if ignore_existing and destination_path.exists():
|
||||
return
|
||||
destination_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
checksum = 'crc32c'
|
||||
for attempt in range(_MAX_DOWNLOAD_ATTEMPTS):
|
||||
try:
|
||||
blob.download_to_filename(destination_path.as_posix(), checksum=checksum)
|
||||
except storage.client.resumable_media.common.DataCorruption:
|
||||
pass
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise DataCorruptionError(f"Checksum ('{checksum}') for {blob.name} failed "
|
||||
f'after {attempt + 1} attempts')
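The download loop above relies on Python's `for ... else`: the `else` clause runs only when the loop finishes without hitting `break`, that is, when every attempt failed. A self-contained sketch of the same retry pattern; `TransientError`, `retry` and `flaky` are hypothetical stand-ins, not part of the Google Cloud Storage client:

```python
class TransientError(Exception):
  """Hypothetical stand-in for a checksum (data corruption) failure."""


def retry(operation, max_attempts: int):
  """Calls `operation` until it succeeds or `max_attempts` is reached."""
  for attempt in range(max_attempts):
    try:
      result = operation()
    except TransientError:
      pass  # Try again on the next iteration.
    else:
      break  # Success: skip the loop's `else` clause.
  else:
    raise RuntimeError(f"failed after {attempt + 1} attempts")
  return result


calls = []


def flaky():
  calls.append(1)
  if len(calls) < 3:
    raise TransientError()
  return "ok"


outcome = retry(flaky, max_attempts=5)
```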
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
bucket = _get_gcs_bucket()
|
||||
if FLAGS.payload == 'data':
|
||||
relative_paths = DATA_RELATIVE_PATHS
|
||||
else:
|
||||
relative_paths = (None,)
|
||||
for relative_path in relative_paths:
|
||||
if relative_path is None:
|
||||
relative_path = str(_get_gcs_root())
|
||||
else:
|
||||
relative_path = str(_get_gcs_root() / relative_path)
|
||||
logging.info("Copying relative path: '%s'", relative_path)
|
||||
blobs = bucket.list_blobs(prefix=relative_path)
|
||||
for blob in blobs:
|
||||
_write_blob_to_destination(blob, FLAGS.task_root)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('payload')
|
||||
flags.mark_flag_as_required('task_root')
|
||||
app.run(main)
|
||||
@@ -0,0 +1,258 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Script to generate ensembled PCQ test predictions."""
|
||||
|
||||
import collections
|
||||
import os
|
||||
import pathlib
|
||||
from typing import List, NamedTuple
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import dill
|
||||
import numpy as np
|
||||
from ogb import lsc
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
# pytype: disable=import-error
|
||||
import datasets
|
||||
|
||||
_NUM_SEEDS = 2
|
||||
|
||||
_CLIP_VALUE = 20.
|
||||
|
||||
_NUM_KFOLD_SPLITS = 10
|
||||
|
||||
_SEED_START = flags.DEFINE_integer(
|
||||
'seed_start', 42, 'Initial seed for the list of ensemble models.')
|
||||
|
||||
_CONFORMER_PATH = flags.DEFINE_string(
|
||||
'conformer_path', None, 'Path to conformer predictions.', required=True)
|
||||
|
||||
_NON_CONFORMER_PATH = flags.DEFINE_string(
|
||||
'non_conformer_path',
|
||||
None,
|
||||
'Path to non-conformer predictions.',
|
||||
required=True)
|
||||
|
||||
_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Output path.')
|
||||
|
||||
_SPLIT = flags.DEFINE_enum('split', 'test', ['test', 'valid'],
|
||||
'Split: valid or test.')
|
||||
|
||||
|
||||
class _Predictions(NamedTuple):
|
||||
predictions: np.ndarray
|
||||
indices: np.ndarray
|
||||
|
||||
|
||||
def _load_dill(fname) -> bytes:
|
||||
with open(fname, 'rb') as f:
|
||||
return dill.load(f)
|
||||
|
||||
|
||||
def _sort_by_indices(predictions: _Predictions) -> _Predictions:
|
||||
order = np.argsort(predictions.indices)
|
||||
return _Predictions(
|
||||
predictions=predictions.predictions[order],
|
||||
indices=predictions.indices[order])
|
||||
|
||||
|
||||
def load_predictions(path: str, split: str) -> _Predictions:
|
||||
"""Load written prediction file."""
|
||||
if len(os.listdir(path)) != 1:
|
||||
raise ValueError('Prediction directory must have exactly '
|
||||
'one prediction sub-directory: %s' % path)
|
||||
|
||||
prediction_subdir = os.listdir(path)[0]
|
||||
return _Predictions(*_load_dill(f'{path}/{prediction_subdir}/{split}.dill'))
|
||||
|
||||
|
||||
def mean_mae_distance(x, y):
|
||||
return np.abs(x - y).mean()
|
||||
|
||||
|
||||
def _load_valid_labels() -> np.ndarray:
|
||||
labels = [label for _, label in datasets.load_smile_strings(with_labels=True)]
|
||||
return np.array([labels[i] for i in datasets.load_splits()['valid']])
|
||||
|
||||
|
||||
def evaluate_valid_predictions(ensembled_predictions: _Predictions):
|
||||
"""Evaluates the predictions on the validation set."""
|
||||
ensembled_predictions = _sort_by_indices(ensembled_predictions)
|
||||
evaluator = lsc.PCQM4MEvaluator()
|
||||
results = evaluator.eval(
|
||||
dict(
|
||||
y_pred=ensembled_predictions.predictions,
|
||||
y_true=_load_valid_labels()))
|
||||
logging.info('MAE on validation dataset: %f', results['mae'])
|
||||
|
||||
|
||||
def clip_predictions(predictions: _Predictions) -> _Predictions:
|
||||
return predictions._replace(
|
||||
predictions=np.clip(predictions.predictions, 0., _CLIP_VALUE))
|
||||
|
||||
|
||||
def _generate_test_prediction_file(test_predictions: np.ndarray,
|
||||
output_path: pathlib.Path) -> pathlib.Path:
|
||||
"""Generates the final file for submission."""
|
||||
|
||||
# Check that predictions are not nuts.
|
||||
assert test_predictions.dtype in [np.float64, np.float32]
|
||||
assert not np.any(np.isnan(test_predictions))
|
||||
assert np.all(np.isfinite(test_predictions))
|
||||
assert test_predictions.min() >= 0.
|
||||
assert test_predictions.max() <= 40.
|
||||
|
||||
# Too risky to overwrite.
|
||||
if output_path.exists():
|
||||
raise ValueError(f'{output_path} already exists')
|
||||
|
||||
# Write to a local directory, and copy to final path (possibly cns).
|
||||
# It is not possible to write directly to CNS.
|
||||
evaluator = lsc.PCQM4MEvaluator()
|
||||
|
||||
evaluator.save_test_submission(dict(y_pred=test_predictions), output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
def merge_complementary_results(split: str, results_a: _Predictions,
|
||||
results_b: _Predictions) -> _Predictions:
|
||||
"""Merges two prediction results with no overlap."""
|
||||
|
||||
indices_a = set(results_a.indices)
|
||||
indices_b = set(results_b.indices)
|
||||
assert not indices_a.intersection(indices_b)
|
||||
|
||||
if split == 'test':
|
||||
merged_indices = list(sorted(indices_a | indices_b))
|
||||
expected_indices = datasets.load_splits()[split]
|
||||
assert np.all(expected_indices == merged_indices)
|
||||
|
||||
predictions = np.concatenate([results_a.predictions, results_b.predictions])
|
||||
indices = np.concatenate([results_a.indices, results_b.indices])
|
||||
predictions = _sort_by_indices(
|
||||
_Predictions(indices=indices, predictions=predictions))
|
||||
return predictions
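The merge above combines two prediction sets that cover disjoint graph indices and returns them ordered by index. A list-based sketch of the same steps; `merge_disjoint` is a hypothetical helper and uses `sorted` where the original uses `np.argsort`:

```python
from typing import List, Sequence, Tuple


def merge_disjoint(
    indices_a: Sequence[int],
    preds_a: Sequence[float],
    indices_b: Sequence[int],
    preds_b: Sequence[float],
) -> Tuple[List[int], List[float]]:
  """Merges two non-overlapping prediction sets, sorted by index."""
  assert not set(indices_a) & set(indices_b), "index sets must not overlap"
  pairs = sorted(
      zip(list(indices_a) + list(indices_b), list(preds_a) + list(preds_b)))
  return [i for i, _ in pairs], [p for _, p in pairs]


merged_indices, merged_preds = merge_disjoint(
    [3, 0], [0.3, 0.0], [2, 1], [0.2, 0.1])
```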
|
||||
|
||||
|
||||
def ensemble_valid_predictions(
|
||||
predictions_list: List[_Predictions]) -> _Predictions:
|
||||
"""Ensembles a list of predictions."""
|
||||
index_to_predictions = collections.defaultdict(list)
|
||||
for predictions in predictions_list:
|
||||
for idx, pred in zip(predictions.indices, predictions.predictions):
|
||||
index_to_predictions[idx].append(pred)
|
||||
|
||||
for idx, ensemble_list in index_to_predictions.items():
|
||||
if len(ensemble_list) != _NUM_SEEDS:
|
||||
raise RuntimeError(
|
||||
'Graph index in the validation set received wrong number of '
|
||||
'predictions to ensemble.')
|
||||
|
||||
index_to_predictions = {
|
||||
k: np.median(pred_list, axis=0)
|
||||
for k, pred_list in index_to_predictions.items()
|
||||
}
|
||||
return _sort_by_indices(
|
||||
_Predictions(
|
||||
indices=np.array(list(index_to_predictions.keys())),
|
||||
predictions=np.array(list(index_to_predictions.values()))))
|
||||
|
||||
|
||||
def ensemble_test_predictions(
|
||||
predictions_list: List[_Predictions]) -> _Predictions:
|
||||
"""Ensembles a list of predictions."""
|
||||
predictions = np.median([pred.predictions for pred in predictions_list],
|
||||
axis=0)
|
||||
common_indices = predictions_list[0].indices
|
||||
for preds in predictions_list[1:]:
|
||||
assert np.all(preds.indices == common_indices)
|
||||
|
||||
return _Predictions(predictions=predictions, indices=common_indices)
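Test-set ensembling above takes the per-sample median across models, which requires every model to report predictions in the same index order. A minimal sketch using only the standard library; `median_ensemble` is a hypothetical name, and the original uses `np.median(..., axis=0)`:

```python
import statistics
from typing import List, Sequence


def median_ensemble(
    predictions_per_model: Sequence[Sequence[float]]) -> List[float]:
  """Per-sample median across models; all inner sequences must align."""
  lengths = {len(p) for p in predictions_per_model}
  assert len(lengths) == 1, "all models must predict the same samples"
  return [statistics.median(column) for column in zip(*predictions_per_model)]


ensembled = median_ensemble([[1.0, 5.0], [2.0, 7.0], [10.0, 6.0]])
```

The median is less sensitive than the mean to a single outlying model, such as the value `10.0` for the first sample here.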
|
||||
|
||||
|
||||
def create_submission_from_predictions(
|
||||
output_path: pathlib.Path, test_predictions: _Predictions) -> pathlib.Path:
|
||||
"""Creates a submission for predictions on a path."""
|
||||
assert _SPLIT.value == 'test'
|
||||
|
||||
output_path = _generate_test_prediction_file(
|
||||
test_predictions.predictions,
|
||||
output_path=output_path / 'submission_files')
|
||||
|
||||
return output_path / 'y_pred_pcqm4m.npz'
|
||||
|
||||
|
||||
def merge_predictions(split: str) -> List[_Predictions]:
|
||||
"""Merges conformer and non-conformer predictions for each fold and seed."""
|
||||
merged_predictions: List[_Predictions] = []
|
||||
seed = _SEED_START.value
|
||||
|
||||
# Load conformer and non-conformer predictions.
|
||||
for unused_seed_group in (0, 1):
|
||||
for k in range(_NUM_KFOLD_SPLITS):
|
||||
conformer_predictions: _Predictions = load_predictions(
|
||||
f'{_CONFORMER_PATH.value}/k{k}_seed{seed}', split)
|
||||
|
||||
non_conformer_predictions: _Predictions = load_predictions(
|
||||
f'{_NON_CONFORMER_PATH.value}/k{k}_seed{seed}', split)
|
||||
|
||||
merged_predictions.append(
|
||||
merge_complementary_results(_SPLIT.value, conformer_predictions,
|
||||
non_conformer_predictions))
|
||||
|
||||
seed += 1
|
||||
return merged_predictions
|
||||
|
||||
|
||||
def main(_):
|
||||
split: str = _SPLIT.value
|
||||
|
||||
# Merge conformer and non-conformer predictions.
|
||||
merged_predictions = merge_predictions(split)
|
||||
|
||||
# Clip before ensembling.
|
||||
clipped_predictions = list(map(clip_predictions, merged_predictions))
|
||||
|
||||
# Ensemble predictions.
|
||||
if split == 'valid':
|
||||
ensembled_predictions = ensemble_valid_predictions(clipped_predictions)
|
||||
else:
|
||||
assert split == 'test'
|
||||
ensembled_predictions = ensemble_test_predictions(clipped_predictions)
|
||||
|
||||
# Clip after ensembling.
|
||||
ensembled_predictions = clip_predictions(ensembled_predictions)
|
||||
|
||||
ensembled_predictions_path = pathlib.Path(_OUTPUT_PATH.value)
|
||||
ensembled_predictions_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(ensembled_predictions_path / f'{split}_predictions.dill',
|
||||
'wb') as f:
|
||||
dill.dump(ensembled_predictions, f)
|
||||
|
||||
if split == 'valid':
|
||||
evaluate_valid_predictions(ensembled_predictions)
|
||||
else:
|
||||
assert split == 'test'
|
||||
output_path = create_submission_from_predictions(ensembled_predictions_path,
|
||||
ensembled_predictions)
|
||||
logging.info('Submission files written to %s', output_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
|
||||
@@ -0,0 +1,449 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""PCQM4M-LSC Jaxline experiment."""
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
from typing import Iterable, Mapping, NamedTuple, Tuple
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import chex
|
||||
import dill
|
||||
import haiku as hk
|
||||
import jax
|
||||
from jax.config import config as jax_config
|
||||
import jax.numpy as jnp
|
||||
from jaxline import experiment
|
||||
from jaxline import platform
|
||||
from jaxline import utils
|
||||
import jraph
|
||||
import numpy as np
|
||||
import optax
|
||||
import tensorflow as tf
|
||||
import tree
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import dataset_utils
|
||||
import datasets
|
||||
import model
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def _get_step_date_label(global_step: int):
|
||||
# ISO-format timestamp with the microseconds removed.
|
||||
date_str = datetime.datetime.now().isoformat().split('.')[0]
|
||||
return f'step_{global_step}_{date_str}'
|
||||
|
||||
|
||||
class _Predictions(NamedTuple):
|
||||
predictions: np.ndarray
|
||||
indices: np.ndarray
|
||||
|
||||
|
||||
def tf1_ema(ema_value, current_value, decay, step):
  """Implements EMA with TF1-style decay warmup."""
  decay = jnp.minimum(decay, (1.0 + step) / (10.0 + step))
  return ema_value * decay + current_value * (1 - decay)
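The warmup in `tf1_ema` caps the decay at `(1 + step) / (10 + step)`, so early steps track the current value closely and the configured decay only takes over later. A plain-Python sketch of the same formula on scalars; `ema_with_warmup` is a hypothetical name and `min` replaces `jnp.minimum`:

```python
def ema_with_warmup(ema_value: float, current_value: float, decay: float,
                    step: int) -> float:
  """Exponential moving average with a step-dependent cap on the decay."""
  decay = min(decay, (1.0 + step) / (10.0 + step))
  return ema_value * decay + current_value * (1.0 - decay)


# At step 0 the effective decay is 1 / 10, so the average moves most of the
# way towards the current value.
first = ema_with_warmup(0.0, 1.0, decay=0.999, step=0)

# At a large step the cap exceeds 0.999, so the configured decay applies.
late = ema_with_warmup(0.0, 1.0, decay=0.999, step=10**6)
```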
|
||||
|
||||
|
||||
def _sort_predictions_by_indices(predictions: _Predictions):
|
||||
sorted_order = np.argsort(predictions.indices)
|
||||
return _Predictions(
|
||||
predictions=predictions.predictions[sorted_order],
|
||||
indices=predictions.indices[sorted_order])
|
||||
|
||||
|
||||
class Experiment(experiment.AbstractExperiment):
|
||||
"""OGB Graph Property Prediction GraphNet experiment."""
|
||||
|
||||
CHECKPOINT_ATTRS = {
|
||||
'_params': 'params',
|
||||
'_opt_state': 'opt_state',
|
||||
'_network_state': 'network_state',
|
||||
'_ema_network_state': 'ema_network_state',
|
||||
'_ema_params': 'ema_params',
|
||||
}
|
||||
|
||||
def __init__(self, mode, init_rng, config):
|
||||
"""Initializes experiment."""
|
||||
super(Experiment, self).__init__(mode=mode, init_rng=init_rng)
|
||||
if mode not in ('train', 'eval', 'train_eval_multithreaded'):
|
||||
raise ValueError(f'Invalid mode {mode}.')
|
||||
|
||||
# Do not use accelerators in data pipeline.
|
||||
tf.config.experimental.set_visible_devices([], device_type='GPU')
|
||||
tf.config.experimental.set_visible_devices([], device_type='TPU')
|
||||
|
||||
self.mode = mode
|
||||
self.init_rng = init_rng
|
||||
self.config = config
|
||||
|
||||
self.loss = None
|
||||
self.forward = None
|
||||
|
||||
# Needed for checkpoint restore.
|
||||
self._params = None
|
||||
self._network_state = None
|
||||
self._opt_state = None
|
||||
self._ema_network_state = None
|
||||
self._ema_params = None
|
||||
|
||||
# _ _
|
||||
# | |_ _ __ __ _(_)_ __
|
||||
# | __| "__/ _` | | "_ \
|
||||
# | |_| | | (_| | | | | |
|
||||
# \__|_| \__,_|_|_| |_|
|
||||
#
|
||||
|
||||
def step(self, global_step: jnp.ndarray, rng: jnp.ndarray, **unused_args):
|
||||
"""See Jaxline base class."""
|
||||
if self.loss is None:
|
||||
self._train_init()
|
||||
|
||||
graph = next(self._train_input)
|
||||
out = self.update_parameters(
|
||||
self._params,
|
||||
self._ema_params,
|
||||
self._network_state,
|
||||
self._ema_network_state,
|
||||
self._opt_state,
|
||||
global_step,
|
||||
rng,
|
||||
graph._asdict())
|
||||
(self._params, self._ema_params, self._network_state,
|
||||
self._ema_network_state, self._opt_state, scalars) = out
|
||||
return utils.get_first(scalars)
|
||||
|
||||
def _construct_loss_config(self):
|
||||
loss_config = getattr(model, self.config.model.loss_config_name)
|
||||
if self.config.model.loss_config_name == 'RegressionLossConfig':
|
||||
return loss_config(
|
||||
mean=datasets.NORMALIZE_TARGET_MEAN,
|
||||
std=datasets.NORMALIZE_TARGET_STD,
|
||||
kwargs=self.config.model.loss_kwargs)
|
||||
else:
|
||||
raise ValueError('Unknown Loss Config')
|
||||
|
||||
def _train_init(self):
|
||||
self.loss = hk.transform_with_state(self._loss)
|
||||
self._train_input = utils.py_prefetch(
|
||||
lambda: self._build_numpy_dataset_iterator('train', is_training=True))
|
||||
init_stacked_graphs = next(self._train_input)
|
||||
init_key = utils.bcast_local_devices(self.init_rng)
|
||||
p_init = jax.pmap(self.loss.init)
|
||||
self._params, self._network_state = p_init(init_key,
|
||||
**init_stacked_graphs._asdict())
|
||||
|
||||
# Learning rate scheduling.
|
||||
lr_schedule = optax.warmup_cosine_decay_schedule(
|
||||
**self.config.optimizer.lr_schedule)
|
||||
|
||||
self.optimizer = getattr(optax, self.config.optimizer.name)(
|
||||
learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs)
|
||||
|
||||
self._opt_state = jax.pmap(self.optimizer.init)(self._params)
|
||||
self.update_parameters = jax.pmap(self._update_parameters, axis_name='i')
|
||||
if self.config.ema:
|
||||
self._ema_params = self._params
|
||||
self._ema_network_state = self._network_state
|
||||
|
||||
def _loss(
|
||||
self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree:
|
||||
|
||||
graph = jraph.GraphsTuple(**graph)
|
||||
model_instance = model.GraphPropertyEncodeProcessDecode(
|
||||
loss_config=self._construct_loss_config(), **self.config.model)
|
||||
loss, scalars = model_instance.get_loss(graph)
|
||||
return loss, scalars
|
||||
|
||||
def _maybe_save_predictions(
|
||||
self,
|
||||
predictions: jnp.ndarray,
|
||||
split: str,
|
||||
global_step: jnp.ndarray,
|
||||
):
|
||||
if not self.config.predictions_dir:
|
||||
return
|
||||
output_dir = os.path.join(self.config.predictions_dir,
|
||||
_get_step_date_label(global_step))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
output_path = os.path.join(output_dir, split + '.dill')
|
||||
|
||||
with open(output_path, 'wb') as f:
|
||||
dill.dump(predictions, f)
|
||||
logging.info('Saved %s predictions at: %s', split, output_path)
|
||||
|
||||
def _build_numpy_dataset_iterator(self, split: str, is_training: bool):
|
||||
dynamic_batch_size_config = (
|
||||
self.config.training.dynamic_batch_size
|
||||
if is_training else self.config.evaluation.dynamic_batch_size)
|
||||
|
||||
return dataset_utils.build_dataset_iterator(
|
||||
split=split,
|
||||
dynamic_batch_size_config=dynamic_batch_size_config,
|
||||
sample_random=self.config.sample_random,
|
||||
debug=self.config.debug,
|
||||
is_training=is_training,
|
||||
**self.config.dataset_config)
|
||||
|
||||
def _update_parameters(
|
||||
self,
|
||||
params: hk.Params,
|
||||
ema_params: hk.Params,
|
||||
network_state: hk.State,
|
||||
ema_network_state: hk.State,
|
||||
opt_state: optax.OptState,
|
||||
global_step: jnp.ndarray,
|
||||
rng: jnp.ndarray,
|
||||
graph: jraph.GraphsTuple,
|
||||
) -> Tuple[hk.Params, hk.Params, hk.State, hk.State, optax.OptState,
|
||||
chex.ArrayTree]:
|
||||
"""Updates parameters."""
|
||||
def get_loss(*x, **graph):
|
||||
(loss, scalars), network_state = self.loss.apply(*x, **graph)
|
||||
return loss, (scalars, network_state)
|
||||
grad_loss_fn = jax.grad(get_loss, has_aux=True)
|
||||
scaled_grads, (scalars, network_state) = grad_loss_fn(
|
||||
params, network_state, rng, **graph)
|
||||
grads = jax.lax.psum(scaled_grads, axis_name='i')
|
||||
updates, opt_state = self.optimizer.update(grads, opt_state, params)
|
||||
params = optax.apply_updates(params, updates)
|
||||
if ema_params is not None:
|
||||
ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step)
|
||||
ema_params = jax.tree_multimap(ema, ema_params, params)
|
||||
ema_network_state = jax.tree_multimap(ema, ema_network_state,
|
||||
network_state)
|
||||
return params, ema_params, network_state, ema_network_state, opt_state, scalars
|
||||
|
||||
# _
|
||||
# _____ ____ _| |
|
||||
# / _ \ \ / / _` | |
|
||||
# | __/\ V / (_| | |
|
||||
# \___| \_/ \__,_|_|
|
||||
#
|
||||
|
||||
def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray,
|
||||
**unused_kwargs) -> chex.ArrayTree:
|
||||
"""See Jaxline base class."""
|
||||
if self.forward is None:
|
||||
self._eval_init()
|
||||
|
||||
if self.config.ema:
|
||||
params = utils.get_first(self._ema_params)
|
||||
state = utils.get_first(self._ema_network_state)
|
||||
else:
|
||||
params = utils.get_first(self._params)
|
||||
state = utils.get_first(self._network_state)
|
||||
rng = utils.get_first(rng)
|
||||
|
||||
split = self.config.evaluation.split
|
||||
predictions, scalars = self._get_predictions(
|
||||
params, state, rng,
|
||||
utils.py_prefetch(
|
||||
functools.partial(
|
||||
self._build_numpy_dataset_iterator, split, is_training=False)))
|
||||
self._maybe_save_predictions(predictions, split, global_step[0])
|
||||
return scalars
|
||||
|
||||
def _sum_regression_scalars(self, preds: jnp.ndarray,
|
||||
graph: jraph.GraphsTuple) -> chex.ArrayTree:
|
||||
"""Creates unnormalised values for accumulation."""
|
||||
targets = graph.globals['target']
|
||||
graph_mask = jraph.get_graph_padding_mask(graph)
|
||||
# Sum for accumulation, normalise later since there are a
|
||||
# variable number of graphs per batch.
|
||||
mae = model.sum_with_mask(jnp.abs(targets - preds), graph_mask)
|
||||
mse = model.sum_with_mask((targets - preds)**2, graph_mask)
|
||||
count = jnp.sum(graph_mask)
|
||||
return {'values': {'mae': mae.item(), 'mse': mse.item()},
|
||||
'counts': {'mae': count.item(), 'mse': count.item()}}
|
||||
|
||||
def _get_prediction(
|
||||
self,
|
||||
params: hk.Params,
|
||||
state: hk.State,
|
||||
rng: jnp.ndarray,
|
||||
graph: jraph.GraphsTuple,
|
||||
) -> np.ndarray:
|
||||
"""Returns predictions for the graphs in a single (padded) batch."""
|
||||
model_out, _ = self.eval_apply(params, state, rng, **graph._asdict())
|
||||
prediction = np.squeeze(model_out['globals'], axis=1)
|
||||
return prediction
|
||||
|
||||
def _get_predictions(
|
||||
self,
|
||||
params: hk.Params,
|
||||
state: hk.State,
|
||||
rng: jnp.ndarray,
|
||||
graph_iterator: Iterable[jraph.GraphsTuple],
|
||||
) -> Tuple[_Predictions, chex.ArrayTree]:
|
||||
all_scalars = []
|
||||
predictions = []
|
||||
graph_indices = []
|
||||
for i, graph in enumerate(graph_iterator):
|
||||
prediction = self._get_prediction(params, state, rng, graph)
|
||||
if 'target' in graph.globals and not jnp.isnan(
|
||||
graph.globals['target']).any():
|
||||
scalars = self._sum_regression_scalars(prediction, graph)
|
||||
all_scalars.append(scalars)
|
||||
num_padding_graphs = jraph.get_number_of_padding_with_graphs_graphs(graph)
|
||||
num_valid_graphs = len(graph.n_node) - num_padding_graphs
|
||||
depadded_prediction = prediction[:num_valid_graphs]
|
||||
predictions.append(depadded_prediction)
|
||||
graph_indices.append(graph.globals['graph_index'][:num_valid_graphs])
|
||||
|
||||
if i % 1000 == 0:
|
||||
logging.info('Generated predictions for %d batches so far', i + 1)
|
||||
|
||||
predictions = _sort_predictions_by_indices(
|
||||
_Predictions(
|
||||
predictions=np.concatenate(predictions),
|
||||
indices=np.concatenate(graph_indices)))
|
||||
|
||||
if all_scalars:
|
||||
sum_all_args = lambda *l: sum(l)
|
||||
# Sum over graphs in the dataset.
|
||||
accum_scalars = tree.map_structure(sum_all_args, *all_scalars)
|
||||
scalars = tree.map_structure(lambda x, y: x / y, accum_scalars['values'],
|
||||
accum_scalars['counts'])
|
||||
else:
|
||||
scalars = {}
|
||||
return predictions, scalars
|
||||
|
||||
def _eval_init(self):
|
||||
self.forward = hk.transform_with_state(self._forward)
|
||||
self.eval_apply = jax.jit(self.forward.apply)
|
||||
|
||||
def _forward(self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree:
|
||||
|
||||
graph = jraph.GraphsTuple(**graph)
|
||||
model_instance = model.GraphPropertyEncodeProcessDecode(
|
||||
loss_config=self._construct_loss_config(), **self.config.model)
|
||||
return model_instance(graph)
|
||||
|
||||
|
||||
def _restore_state_to_in_memory_checkpointer(restore_path):
|
||||
"""Initializes experiment state from a checkpoint."""
|
||||
|
||||
# Load pretrained experiment state.
|
||||
python_state_path = os.path.join(restore_path, 'checkpoint.dill')
|
||||
with open(python_state_path, 'rb') as f:
|
||||
pretrained_state = dill.load(f)
|
||||
logging.info('Restored checkpoint from %s', python_state_path)
|
||||
|
||||
# Assign state to a dummy experiment instance for the in-memory checkpointer,
|
||||
# broadcasting to devices.
|
||||
dummy_experiment = Experiment(
|
||||
mode='train', init_rng=0, config=FLAGS.config.experiment_kwargs.config)
|
||||
for attribute, key in Experiment.CHECKPOINT_ATTRS.items():
|
||||
setattr(dummy_experiment, attribute,
|
||||
utils.bcast_local_devices(pretrained_state[key]))
|
||||
|
||||
jaxline_state = dict(
|
||||
global_step=pretrained_state['global_step'],
|
||||
experiment_module=dummy_experiment)
|
||||
snapshot = utils.SnapshotNT(0, jaxline_state)
|
||||
|
||||
# Finally, seed the jaxline `utils.InMemoryCheckpointer` global dict.
|
||||
utils.GLOBAL_CHECKPOINT_DICT['latest'] = utils.CheckpointNT(
|
||||
threading.local(), [snapshot])
|
||||
|
||||
|
||||
def _save_state_from_in_memory_checkpointer(
|
||||
save_path, experiment_class: experiment.AbstractExperiment):
|
||||
"""Saves experiment state to a checkpoint."""
|
||||
logging.info('Saving model.')
|
||||
for checkpoint_name, checkpoint in utils.GLOBAL_CHECKPOINT_DICT.items():
|
||||
if not checkpoint.history:
|
||||
logging.info('Nothing to save in "%s"', checkpoint_name)
|
||||
continue
|
||||
|
||||
pickle_nest = checkpoint.history[-1].pickle_nest
|
||||
global_step = pickle_nest['global_step']
|
||||
|
||||
state_dict = {'global_step': global_step}
|
||||
for attribute, key in experiment_class.CHECKPOINT_ATTRS.items():
|
||||
state_dict[key] = utils.get_first(
|
||||
getattr(pickle_nest['experiment_module'], attribute))
|
||||
save_dir = os.path.join(
|
||||
save_path, checkpoint_name, _get_step_date_label(global_step))
|
||||
python_state_path = os.path.join(save_dir, 'checkpoint.dill')
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
with open(python_state_path, 'wb') as f:
|
||||
dill.dump(state_dict, f)
|
||||
logging.info(
|
||||
'Saved "%s" checkpoint to %s', checkpoint_name, python_state_path)
|
||||
|
||||
|
||||
def _setup_signals(save_model_fn):
|
||||
"""Sets up a signal for model saving."""
|
||||
# Save a model on Ctrl+C.
|
||||
def sigint_handler(unused_sig, unused_frame):
|
||||
# Ideally, rather than saving immediately, we would then "wait" for a good
|
||||
# time to save. In practice this reads from an in-memory checkpoint that
|
||||
# only saves every 30 seconds or so, so chances of race conditions are very
|
||||
# small.
|
||||
save_model_fn()
|
||||
logging.info(r'Use `Ctrl+\` to save and exit.')
|
||||
|
||||
# Exit on `Ctrl+\`, saving a model.
|
||||
prev_sigquit_handler = signal.getsignal(signal.SIGQUIT)
|
||||
def sigquit_handler(unused_sig, unused_frame):
|
||||
# Restore previous handler early, just in case something goes wrong in the
|
||||
# next lines, so it is possible to press again and exit.
|
||||
signal.signal(signal.SIGQUIT, prev_sigquit_handler)
|
||||
save_model_fn()
|
||||
logging.info(r'Exiting on `Ctrl+\`')
|
||||
|
||||
# Re-raise for clean exit.
|
||||
os.kill(os.getpid(), signal.SIGQUIT)
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
signal.signal(signal.SIGQUIT, sigquit_handler)
|
||||
|
||||
|
||||
def main(argv, experiment_class: experiment.AbstractExperiment):
|
||||
|
||||
# Maybe restore a model.
|
||||
restore_path = FLAGS.config.restore_path
|
||||
if restore_path:
|
||||
_restore_state_to_in_memory_checkpointer(restore_path)
|
||||
|
||||
# Maybe save a model.
|
||||
save_dir = os.path.join(FLAGS.config.checkpoint_dir, 'models')
|
||||
if FLAGS.config.one_off_evaluate:
|
||||
save_model_fn = lambda: None # No need to save checkpoint in this case.
|
||||
else:
|
||||
save_model_fn = functools.partial(
|
||||
_save_state_from_in_memory_checkpointer, save_dir, experiment_class)
|
||||
_setup_signals(save_model_fn) # Save on Ctrl+C (continue) or Ctrl+\ (exit).
|
||||
|
||||
try:
|
||||
platform.main(experiment_class, argv)
|
||||
finally:
|
||||
save_model_fn() # Save at the end of training or in case of exception.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
jax_config.update('jax_debug_nans', False)
|
||||
flags.mark_flag_as_required('config')
|
||||
app.run(lambda argv: main(argv, Experiment))
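The evaluation loop in `_get_predictions` sums per-batch `values` and `counts` and divides only once at the end, because padded batches hold a variable number of real graphs. A minimal NumPy sketch of that pattern follows. It is illustrative only: `sum_with_mask` here is a hypothetical stand-in for `model.sum_with_mask` (whose implementation is not part of this diff), and the batches are made-up numbers.

```python
import numpy as np


def sum_with_mask(values, mask):
  """Sums `values` where `mask` is True (hypothetical stand-in)."""
  return float(np.sum(np.where(mask, values, 0.0)))


def batch_scalars(preds, targets, mask):
  """Unnormalised error sum for one padded batch, plus the real-graph count."""
  abs_err = np.abs(targets - preds)
  return {'values': {'mae': sum_with_mask(abs_err, mask)},
          'counts': {'mae': int(np.sum(mask))}}


def normalise(all_scalars):
  """Adds up per-batch sums and counts, then divides once."""
  total = sum(s['values']['mae'] for s in all_scalars)
  count = sum(s['counts']['mae'] for s in all_scalars)
  return {'mae': total / count}


# Two padded batches; the last entry of the second batch is padding.
batch_a = batch_scalars(np.array([1.0, 2.0, 3.0]), np.array([1.5, 2.0, 2.0]),
                        np.array([True, True, True]))
batch_b = batch_scalars(np.array([4.0, 0.0]), np.array([2.0, 9.0]),
                        np.array([True, False]))
result = normalise([batch_a, batch_b])
```

Averaging the two per-batch means instead would weight the single real graph in the second batch as heavily as the three graphs in the first, which is why the sums and counts are kept separate until the end.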
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generate conformer features to be used for training/predictions."""
|
||||
|
||||
import multiprocessing as mp
|
||||
import pickle
|
||||
from typing import List
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import conformer_utils
|
||||
import datasets
|
||||
|
||||
_SPLITS = flags.DEFINE_spaceseplist(
|
||||
'splits', ['test'], 'Splits to compute conformer features for.')
|
||||
|
||||
_OUTPUT_FILE = flags.DEFINE_string(
|
||||
'output_file',
|
||||
None,
|
||||
required=True,
|
||||
help='Output file name to write the generated conformer features to.')
|
||||
|
||||
_NUM_PROCS = flags.DEFINE_integer(
|
||||
'num_parallel_procs', 64,
|
||||
'Number of parallel processes to use for conformer generation.')
|
||||
|
||||
|
||||
def generate_conformer_features(smiles: List[str]) -> List[np.ndarray]:
|
||||
# Conformer generation is a CPU-bound task and hence can get a boost from
|
||||
# parallel processing.
|
||||
# To avoid GIL, we choose multiprocessing instead of the
|
||||
# simpler multi-threading option here for parallel computing.
|
||||
with mp.Pool(_NUM_PROCS.value) as pool:
|
||||
return list(pool.map(conformer_utils.compute_conformer, smiles))
|
||||
|
||||
|
||||
def main(_):
|
||||
smiles = datasets.load_smile_strings(with_labels=False)
|
||||
indices = set()
|
||||
for split in _SPLITS.value:
|
||||
indices.update(datasets.load_splits()[split])
|
||||
|
||||
smiles = [smiles[i] for i in sorted(indices)]
|
||||
conformers = generate_conformer_features(smiles)
|
||||
smiles_to_conformers = dict(zip(smiles, conformers))
|
||||
|
||||
with open(_OUTPUT_FILE.value, 'wb') as f:
|
||||
pickle.dump(smiles_to_conformers, f)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
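The `main` function of the conformer script takes the union of the requested split indices, sorts it, and keys the resulting features by SMILES string. The sketch below reproduces only that bookkeeping on toy data. The SMILES list, the split dictionary and the "feature" (string length in place of a real conformer) are all assumptions made for illustration; `select_smiles` is a hypothetical helper, not part of the repository.

```python
from typing import Dict, List, Sequence


def select_smiles(smiles: Sequence[str],
                  splits: Dict[str, Sequence[int]],
                  split_names: List[str]) -> List[str]:
  """Returns the SMILES strings for the union of the requested splits."""
  indices = set()
  for name in split_names:
    indices.update(splits[name])
  # Sorting makes the output order independent of set iteration order.
  return [smiles[i] for i in sorted(indices)]


# Toy data (made up for illustration).
all_smiles = ['C', 'CC', 'CCO', 'C', 'O']
toy_splits = {'train': [0, 1], 'valid': [2, 3], 'test': [3, 4]}

selected = select_smiles(all_smiles, toy_splits, ['train', 'valid'])
# Stand-in "features": the string length instead of a real conformer.
features = [len(s) for s in selected]
smiles_to_features = dict(zip(selected, features))
```

Because the output is a dictionary keyed by SMILES string, molecules that appear more than once (here `'C'`) collapse to a single entry, with the last computed value kept.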
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Generates the k-fold validation splits."""
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
# pylint: disable=g-bad-import-order
|
||||
import datasets
|
||||
|
||||
|
||||
_OUTPUT_DIR = flags.DEFINE_string(
|
||||
'output_dir', None, required=True,
|
||||
help='Output directory to write the splits to')
|
||||
|
||||
|
||||
K = 10
|
||||
|
||||
|
||||
def main(argv):
|
||||
del argv
|
||||
valid_indices = datasets.load_splits()['valid']
|
||||
k_splits = np.split(valid_indices, K)
|
||||
os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
|
||||
for k_i, split in enumerate(k_splits):
|
||||
fname = os.path.join(_OUTPUT_DIR.value, f'{k_i}.pkl')
|
||||
with open(fname, 'wb') as f:
|
||||
pickle.dump(split, f)
|
||||
logging.info('Saved: %s', fname)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run(main)
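The split script relies on `np.split`, which requires the array length to be exactly divisible by the number of sections. The sketch below shows that behaviour on toy indices, together with `np.array_split` as a tolerant alternative. The index values are made up, and the use of `np.array_split` is a suggestion for uneven data rather than something the script does.

```python
import numpy as np

K = 10

# Toy "validation indices"; 30 elements divide evenly into 10 folds of 3.
valid_indices = np.arange(100, 130)
k_splits = np.split(valid_indices, K)

# Fold k is held out for early stopping; the other nine folds can be added
# to the training data, as described in the training script's comments.
k = 4
held_out = k_splits[k]
extra_train = np.concatenate([s for i, s in enumerate(k_splits) if i != k])

# np.split raises ValueError when the length is not divisible by K.
try:
  np.split(np.arange(31), K)
  uneven_split_worked = True
except ValueError:
  uneven_split_worked = False

# np.array_split accepts uneven lengths and makes the first folds larger.
uneven_folds = np.array_split(np.arange(31), K)
```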
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,21 @@
|
||||
wheel
|
||||
absl-py>=0.12.0
|
||||
chex>=0.0.7
|
||||
dill>=0.3.3
|
||||
dm-haiku>=0.0.4
|
||||
# jaxline
|
||||
git+https://github.com/deepmind/jaxline@927695a0b88e480d085590736e2daa36c2f2c68b
|
||||
optax>=0.0.8
|
||||
ml_collections
|
||||
numpy>=1.16.4
|
||||
tensorflow>=2.5.0
|
||||
tensorflow-datasets>=4.3.0
|
||||
dm-tree>=0.1.6
|
||||
ogb>=1.3.1
|
||||
# graph_nets
|
||||
git+https://github.com/deepmind/graph_nets.git@64771dff0d74ca8e77b1f1dcd5a7d26634356d61
|
||||
rdkit-pypi>=2021.3.2.3
|
||||
jaxlib>=0.1.57+cuda110
|
||||
# jraph
|
||||
git+https://github.com/deepmind/jraph.git@80d2f9e4d82f841a71d56ba44fa9e8781e93ae1f
|
||||
google-cloud-storage
|
||||
Executable
+48
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
while getopts ":r:" opt; do
|
||||
case ${opt} in
|
||||
r )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: $0 -r <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
|
||||
DATA_ROOT=${TASK_ROOT}/data/
|
||||
|
||||
|
||||
python "${SCRIPT_DIR}"/generate_validation_splits.py \
|
||||
--output_dir="${DATA_ROOT}/k_fold_splits"
|
||||
|
||||
mkdir -p ${DATA_ROOT}/preprocessed/
|
||||
python "${SCRIPT_DIR}"/generate_conformer_features.py \
|
||||
--splits="test valid train" \
|
||||
--num_parallel_procs=32 \
|
||||
--output_file="${DATA_ROOT}/preprocessed/smile_to_conformer.pkl"
|
||||
Executable
+107
@@ -0,0 +1,107 @@
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Generates test predictions from trained model checkpoints.
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
while getopts ":r:" opt; do
|
||||
case ${opt} in
|
||||
r )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: run_pretrained_eval.sh -r <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
SPLIT="test"
|
||||
|
||||
DATA_ROOT=${TASK_ROOT}/data/ # For valid k-fold splits.
|
||||
MODELS_ROOT=${TASK_ROOT}/models/
|
||||
CACHED_CONFORMERS_DIR=${TASK_ROOT}/online_conformers/$SPLIT
|
||||
CACHED_CONFORMERS_FILE=${CACHED_CONFORMERS_DIR}/smiles_to_conformers.pkl
|
||||
OUTPUT_DIR=${TASK_ROOT}/predictions/
|
||||
|
||||
# First generate and cache the conformer feature data for the test split.
|
||||
mkdir -p ${CACHED_CONFORMERS_DIR}
|
||||
time python "${SCRIPT_DIR}"/generate_conformer_features.py \
|
||||
--splits=$SPLIT \
|
||||
--num_parallel_procs=32 \
|
||||
--output_file=${CACHED_CONFORMERS_FILE}
|
||||
|
||||
seed=42
|
||||
# Share GPU between two runs by disabling XLA memory pre-allocation.
|
||||
export XLA_PYTHON_CLIENT_PREALLOCATE=false
|
||||
|
||||
for seed_group in 0 1; do
|
||||
for k in `seq 0 9`; do
|
||||
# Conformer
|
||||
predictions_dir="${OUTPUT_DIR}/predictions/${SPLIT}/conformer/k${k}_seed${seed}"
|
||||
time python "${SCRIPT_DIR}"/experiment.py \
|
||||
--jaxline_mode="eval" \
|
||||
--config="${SCRIPT_DIR}"/config.py \
|
||||
--config.experiment_kwargs.config.evaluation.split=$SPLIT \
|
||||
--config.restore_path=${MODELS_ROOT}/conformer/k${k}_seed${seed} \
|
||||
--config.experiment_kwargs.config.predictions_dir=${predictions_dir} \
|
||||
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
|
||||
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=$k \
|
||||
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
|
||||
--config.experiment_kwargs.config.dataset_config.filter_in_or_out_samples_with_nans_in_conformers="out" \
|
||||
--config.experiment_kwargs.config.model.latent_size=256 \
|
||||
--config.experiment_kwargs.config.model.mlp_hidden_size=1024 \
|
||||
--config.experiment_kwargs.config.model.num_message_passing_steps=32 \
|
||||
--config.one_off_evaluate &
|
||||
|
||||
|
||||
# Non-Conformer
|
||||
predictions_dir="${OUTPUT_DIR}/predictions/${SPLIT}/non_conformer/k${k}_seed${seed}"
|
||||
time python "${SCRIPT_DIR}"/experiment.py \
|
||||
--jaxline_mode="eval" \
|
||||
--config="${SCRIPT_DIR}"/config.py \
|
||||
--config.experiment_kwargs.config.evaluation.split=$SPLIT \
|
||||
--config.restore_path=${MODELS_ROOT}/non_conformer/k${k}_seed${seed} \
|
||||
--config.experiment_kwargs.config.predictions_dir=${predictions_dir} \
|
||||
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
|
||||
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=$k \
|
||||
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
|
||||
--config.experiment_kwargs.config.dataset_config.filter_in_or_out_samples_with_nans_in_conformers="in" \
|
||||
--config.experiment_kwargs.config.model.num_message_passing_steps=50 \
|
||||
--config.experiment_kwargs.config.model.add_relative_distance=false \
|
||||
--config.experiment_kwargs.config.model.add_relative_displacement=false \
|
||||
--config.one_off_evaluate &
|
||||
|
||||
wait
|
||||
|
||||
((seed=seed+1))
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
# Ensemble predictions.
|
||||
time python "${SCRIPT_DIR}"/ensemble_predictions.py \
|
||||
--seed_start=42 \
|
||||
--split=$SPLIT \
|
||||
--conformer_path="${OUTPUT_DIR}/predictions/${SPLIT}/conformer" \
|
||||
--non_conformer_path="${OUTPUT_DIR}/predictions/${SPLIT}/non_conformer" \
|
||||
--output_path="${OUTPUT_DIR}/ensembled_predictions/"
|
||||
Executable
+134
@@ -0,0 +1,134 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 DeepMind Technologies Limited.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
while getopts ":r:" opt; do
|
||||
case ${opt} in
|
||||
r )
|
||||
TASK_ROOT=$OPTARG
|
||||
;;
|
||||
\? )
|
||||
echo "Usage: run_training.sh -r <Task root directory>"
|
||||
;;
|
||||
: )
|
||||
echo "Invalid option: $OPTARG requires an argument" 1>&2
|
||||
;;
|
||||
esac
|
||||
done
|
||||
shift $((OPTIND -1))
|
||||
|
||||
# Get this script's directory.
|
||||
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
|
||||
|
||||
echo "
|
||||
This script is provided for illustrative purposes. It is not practical for
|
||||
actual training since it only uses a single machine, and likely requires
|
||||
reducing the batch size and/or model size to fit on a single GPU.
|
||||
|
||||
For the actual submission we used training distributed in different ways:
|
||||
* We used 4x Cloud TPU v4s to implement batch parallelism, so batch
|
||||
size is effectively 8 times larger than the values in the config.
|
||||
* We ran the online early-stopping evaluator on a separate machine with an
|
||||
NVIDIA V100 GPU.
|
||||
* We ran identical replicas of all of the above for each of the 40 models
|
||||
trained.
|
||||
|
||||
Using those mechanisms, training runs at ~4 and 6 steps per second, reaching
|
||||
500k steps in about 1.5 days.
|
||||
"
|
||||
|
||||
echo "
|
||||
Pre-requisites (See README):
|
||||
* cd into directory containing the 'ogb_lsc' folder.
|
||||
* Python dependencies have been installed.
|
||||
* pre-processed data is available at the hardcoded paths
|
||||
"
|
||||
|
||||
read -p "Press enter to continue"
|
||||
|
||||
DATA_ROOT=${TASK_ROOT}/data/
|
||||
CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/
|
||||
CACHED_CONFORMERS_FILE="${DATA_ROOT}/preprocessed/smile_to_conformer.pkl"
|
||||
|
||||
# We run two seeds for each model of the k=10 k-fold.
|
||||
BASE_SEED=42
|
||||
for SEED_OFFSET in 0 10; do
|
||||
for K_FOLD_INDEX in {0..9}; do
|
||||
MODEL_SEED=`expr ${BASE_SEED} + ${SEED_OFFSET} + ${K_FOLD_INDEX}`
|
||||
|
||||
echo "Running k=${K_FOLD_INDEX} with init seed ${MODEL_SEED}"
|
||||
|
||||
# This runs training (each model is trained on train split + 90% of the
|
||||
# validation split) with early stopping, storing both "latest" model and
|
||||
# "best" early-stopped model at `--config.checkpoint_dir`.
|
||||
|
||||
# Models are early stopped based on accuracy on the 10% of the validation
|
||||
# data that is left out from training (each K_FOLD_INDEX leaves a different
|
||||
# 10% of the data out).
|
||||
|
||||
# Models are stored at the end of training. Intermediate models can also be
|
||||
# stored while training by sending a SIGINT signal (Ctrl+C) which will not
|
||||
# interrupt the training.
|
||||
|
||||
# It is possible to interrupt training using (Ctrl+\) and then continue
|
||||
# by passing the corresponding `--config.restore_path=${RESTORE_PATH}`, which is
|
||||
# stored when interrupting.
|
||||
|
||||
# Conformer.
|
||||
SUFFIX="conformer/k${K_FOLD_INDEX}_seed${MODEL_SEED}"
|
||||
echo "Running ${SUFFIX}"
|
||||
python "${SCRIPT_DIR}"/experiment.py \
|
||||
--jaxline_mode="train_eval_multithreaded" \
|
||||
--config="${SCRIPT_DIR}"/config.py \
|
||||
--config.random_seed=${MODEL_SEED} \
|
||||
--config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \
|
||||
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
|
||||
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=${K_FOLD_INDEX} \
|
||||
--config.experiment_kwargs.config.dataset_config.num_k_fold_splits=10 \
|
||||
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
|
||||
--config.experiment_kwargs.config.model.latent_size=256 \
|
||||
--config.experiment_kwargs.config.model.mlp_hidden_size=1024 \
|
||||
--config.experiment_kwargs.config.model.num_message_passing_steps=32 \
|
||||
--config.experiment_kwargs.config.evaluation.split="valid"
|
||||
|
||||
|
||||
# Non-Conformer.
|
||||
SUFFIX="non_conformer/k${K_FOLD_INDEX}_seed${MODEL_SEED}"
|
||||
echo "Running ${SUFFIX}"
|
||||
python "${SCRIPT_DIR}"/experiment.py \
|
||||
--jaxline_mode="train_eval_multithreaded" \
|
||||
--config="${SCRIPT_DIR}"/config.py \
|
||||
--config.random_seed=${MODEL_SEED} \
|
||||
--config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \
|
||||
--config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \
|
||||
--config.experiment_kwargs.config.dataset_config.k_fold_split_id=${K_FOLD_INDEX} \
|
||||
--config.experiment_kwargs.config.dataset_config.num_k_fold_splits=10 \
|
||||
--config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \
|
||||
--config.experiment_kwargs.config.model.num_message_passing_steps=50 \
|
||||
--config.experiment_kwargs.config.model.add_relative_distance=false \
|
||||
--config.experiment_kwargs.config.model.add_relative_displacement=false \
|
||||
--config.experiment_kwargs.config.evaluation.split="valid"
|
||||
|
||||
done
|
||||
done
|
||||
|
||||
# Each of the 40 jobs (two for each value of k, for both conformer and
|
||||
# non-conformer models) stores its early-stopped model at a path of the form:
|
||||
# RESTORE_PATH=${--config.checkpoint_dir}/models/best/step_${STEP}_${TIMESTAMP}
|
||||
# These can then be used as the "RESTORE_PATHS" for `./run_pretrained_eval.sh`.
|
||||
|
||||
echo "Done"
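The training script derives each model seed as `BASE_SEED + SEED_OFFSET + K_FOLD_INDEX`, while `run_pretrained_eval.sh` increments a running `seed` counter across two passes over the ten folds. The short Python sketch below enumerates both schemes to check that they name the same `k${k}_seed${seed}` directories. The function names are hypothetical and exist only for this illustration.

```python
from typing import List, Tuple


def training_runs(base_seed: int = 42,
                  seed_offsets: Tuple[int, ...] = (0, 10),
                  num_folds: int = 10) -> List[Tuple[int, int]]:
  """(k, seed) pairs as computed by MODEL_SEED in the training script."""
  return [(k, base_seed + offset + k)
          for offset in seed_offsets for k in range(num_folds)]


def eval_runs(base_seed: int = 42, num_seed_groups: int = 2,
              num_folds: int = 10) -> List[Tuple[int, int]]:
  """(k, seed) pairs as produced by the counter in the evaluation script."""
  seed = base_seed
  runs = []
  for _ in range(num_seed_groups):
    for k in range(num_folds):
      runs.append((k, seed))
      seed += 1
  return runs


runs = training_runs()
suffixes = [f'k{k}_seed{seed}' for k, seed in runs]
```

Both schemes yield 20 (k, seed) pairs, from `k0_seed42` to `k9_seed61`. Each pair is used once for the conformer model and once for the non-conformer model, giving the 40 jobs mentioned above.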
|
||||