diff --git a/README.md b/README.md index 983b952..d0086f3 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ https://deepmind.com/research/publications/ ## Projects +* [Open Graph Benchmark - Large-Scale Challenge (OGB-LSC)](ogb_lsc) * [Synthetic Returns for Long-Term Credit Assignment](synthetic_returns) * [A Deep Learning Approach for Characterizing Major Galaxy Mergers](galaxy_mergers) * [Better, Faster Fermionic Neural Networks](kfac_ferminet_alpha) (KFAC implementation) @@ -73,7 +74,6 @@ https://deepmind.com/research/publications/ - ## Disclaimer *This is not an official Google product.* diff --git a/ogb_lsc/README.md b/ogb_lsc/README.md new file mode 100644 index 0000000..f428648 --- /dev/null +++ b/ogb_lsc/README.md @@ -0,0 +1,29 @@ +# DeepMind entry for PCQM4M-LSC + +This repository contains DeepMind's entry to the [PCQM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry) and +[MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph) +tracks of the [OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/) +(OGB-LSC). + +For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf). + +## Code structure + +* `pcq/`: Scripts for training and evaluating on the [PCQ dataset](https://ogb.stanford.edu/docs/graphprop/). +* `mag/`: Scripts for training and evaluating on the [MAG dataset](https://ogb.stanford.edu/docs/nodeprop/).
+ +## Citation + +To cite this work: + +```latex +@article{deepmind2021ogb, + author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea + Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro + Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar + Veli\v{c}kovi\'{c}}, + title = {Large-scale graph representation learning with very deep GNNs and + self-supervision}, + year = {2021}, +} +``` diff --git a/ogb_lsc/mag/README.md b/ogb_lsc/mag/README.md new file mode 100644 index 0000000..210ecc7 --- /dev/null +++ b/ogb_lsc/mag/README.md @@ -0,0 +1,128 @@ +# DeepMind entry for MAG240M-LSC + +This repository contains DeepMind's entry to the [MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph) track of the +[OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/) (OGB-LSC). + +For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf). + +## DeepMind MAG Team ("Academic") + +(in alphabetical order) + +- Ravichandra Addanki +- Peter Battaglia +- David Budden +- Andreea Deac +- Jonathan Godwin +- Thomas Keck +- Alvaro Sanchez-Gonzalez +- Jacklynn Stott +- Shantanu Thakoor +- Petar Veličković + +## Performance + +Our final test set performance was achieved by pooling an ensemble of 10 folds. +See [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf) for details. + +Each model was trained for < 72 hours using 4x Google Cloud TPUv4 and 1x AMD +EPYC 7B12 64-core CPU @2.25GHz. + +Inference takes < 12 hours on 4x NVIDIA V100 16GB GPU and 1x Intel Xeon Gold +6148 20-core CPU @2.40GHz. 
+ +# Running our model + +## Setup + +You can set up a Python virtual environment (you might need to install the +`python3-venv` package first) with all needed dependencies inside the forked +`deepmind_research` repository using: + +```bash +python3 -m venv /tmp/mag_venv +source /tmp/mag_venv/bin/activate +pip3 install --upgrade pip setuptools wheel +pip3 install -r ogb_lsc/mag/requirements.txt +``` + +## Download and pre-process data + + +**1. Download the dataset using the contest toolkit ([here](https://ogb.stanford.edu/kddcup2021/mag240m/#dataset)) to a local directory +`ROOT`.** + +**2. Run this script to reorganize the data into a flat directory structure with +transparent names** + +```bash +/bin/bash organize_data.sh -r ROOT +``` + +Once this completes, a new directory `ROOT/mag240m_kddcup2021/raw` will be +created, with contents: + +- `node_feat.npy` +- `node_label.npy` +- `node_year.npy` +- `author_affiliated_with_institution_edges.npy` +- `author_writes_paper_edges.npy` +- `paper_cites_paper_edges.npy` +- `train_idx.npy` +- `valid_idx.npy` +- `test_idx.npy` + +We refer to this as the "raw" data. + +**3. Run the preprocessing code** +```bash +/bin/bash run_preprocessing.sh -r ROOT +``` + +The pre-processing is very time- and memory-consuming, and should only be run +to verify the full pipeline. You can download the pre-processed data using the +following script, for use in training and evaluating models: + +```bash +python3 download_mag.py --task_root=${HOME}/mag --payload="data" +``` + + +## Reproducing our final results + +We have provided pre-trained weights of our final submission for convenience. +They can be downloaded with: +```bash +python3 download_mag.py --task_root=${HOME}/mag --payload="models" +``` +Then to reproduce our final results, please run `bash run_pretrain_eval.sh`. + + +## Retraining our model + +Disclaimer: This script is provided for illustrative purposes.
It is not +practical for actual training since it only uses a single machine, and likely +requires reducing the batch size and/or model size to fit on a single GPU. + +If you still want to train a model, please run `run_training.sh`. To simply +validate that the code is running correctly on your hardware setup, consider +setting `debug=True` in `config.py`, which trains a smaller model. + + +# Citation + +To cite this work (together with our PCQM4M-LSC entry): + +```latex +@article{deepmind2021ogb, + author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea + Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro + Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar + Veli\v{c}kovi\'{c}}, + title = {Large-scale graph representation learning with very deep GNNs and + self-supervision}, + year = {2021}, +} +``` + +Our technical report can be found [here](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf). diff --git a/ogb_lsc/mag/batching_utils.py b/ogb_lsc/mag/batching_utils.py new file mode 100644 index 0000000..8839205 --- /dev/null +++ b/ogb_lsc/mag/batching_utils.py @@ -0,0 +1,138 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Dynamic batching utilities.""" + +from typing import Generator, Iterable, Iterator, Sequence, Tuple + +import jax.tree_util as tree +import jraph +import numpy as np + +_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph") + + +def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple], + n_node: int, n_edge: int, + n_graph: int) -> Generator[jraph.GraphsTuple, None, None]: + """Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`. + + Elements of the `graphs_tuple_iterator` will be incrementally added to a batch + until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This + means each yielded batch may contain a different number of graphs. + + For situations where you have variable sized data, it's useful to be able to + have variable sized batches. This is especially the case if you have a loss + defined on the variable shaped element (for example, nodes in a graph). + + Args: + graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`. + n_node: The maximum number of nodes in a batch. + n_edge: The maximum number of edges in a batch. + n_graph: The maximum number of graphs in a batch. + + Yields: + A `jraph.GraphsTuple` batch of graphs. + + Raises: + ValueError: if the number of graphs is < 2. + RuntimeError: if the `graphs_tuple_iterator` contains elements which are not + `jraph.GraphsTuple`s. + RuntimeError: if a graph is found which is larger than the batch size.
+ """ + if n_graph < 2: + raise ValueError("The number of graphs in a batch must be greater than or " + f"equal to `2` for padding with graphs, got {n_graph}.") + valid_batch_size = (n_node - 1, n_edge, n_graph - 1) + accumulated_graphs = [] + num_accumulated_nodes = 0 + num_accumulated_edges = 0 + num_accumulated_graphs = 0 + for element in graphs_tuple_iterator: + element_nodes, element_edges, element_graphs = _get_graph_size(element) + if _is_over_batch_size(element, valid_batch_size): + graph_size = element_nodes, element_edges, element_graphs + graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)} + batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)} + raise RuntimeError("Found graph bigger than batch size. Valid Batch " + f"Size: {batch_size}, Graph Size: {graph_size}") + + if not accumulated_graphs: + # If this is the first element of the batch, set it and continue. + accumulated_graphs = [element] + num_accumulated_nodes = element_nodes + num_accumulated_edges = element_edges + num_accumulated_graphs = element_graphs + continue + else: + # Otherwise check if there is space for the graph in the batch: + if ((num_accumulated_graphs + element_graphs > n_graph - 1) or + (num_accumulated_nodes + element_nodes > n_node - 1) or + (num_accumulated_edges + element_edges > n_edge)): + # If there is not, yield the current batch and start a new one. + batched_graph = _batch_np(accumulated_graphs) + yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph) + accumulated_graphs = [element] + num_accumulated_nodes = element_nodes + num_accumulated_edges = element_edges + num_accumulated_graphs = element_graphs + else: + # Otherwise, add the graph to the current batch. + accumulated_graphs.append(element) + num_accumulated_nodes += element_nodes + num_accumulated_edges += element_edges + num_accumulated_graphs += element_graphs + + # We may still have data in batched graph.
+ if accumulated_graphs: + batched_graph = _batch_np(accumulated_graphs) + yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph) + + +def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple: + # Calculates offsets for sender and receiver arrays, caused by concatenating + # the nodes arrays. + offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]])) + + def _map_concat(nests): + concat = lambda *args: np.concatenate(args) + return tree.tree_multimap(concat, *nests) + + return jraph.GraphsTuple( + n_node=np.concatenate([g.n_node for g in graphs]), + n_edge=np.concatenate([g.n_edge for g in graphs]), + nodes=_map_concat([g.nodes for g in graphs]), + edges=_map_concat([g.edges for g in graphs]), + globals=_map_concat([g.globals for g in graphs]), + senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]), + receivers=np.concatenate( + [g.receivers + o for g, o in zip(graphs, offsets)])) + + +def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]: + n_node = np.sum(graph.n_node) + n_edge = len(graph.senders) + n_graph = len(graph.n_node) + return n_node, n_edge, n_graph + + +def _is_over_batch_size( + graph: jraph.GraphsTuple, + graph_batch_size: Iterable[int], +) -> bool: + graph_size = _get_graph_size(graph) + return any([x > y for x, y in zip(graph_size, graph_batch_size)]) + + + diff --git a/ogb_lsc/mag/config.py b/ogb_lsc/mag/config.py new file mode 100644 index 0000000..059d288 --- /dev/null +++ b/ogb_lsc/mag/config.py @@ -0,0 +1,117 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experiment config for MAG240M-LSC entry.""" + +from jaxline import base_config +from ml_collections import config_dict + + +def get_config(debug: bool = False) -> config_dict.ConfigDict: + """Get Jaxline experiment config.""" + config = base_config.get_base_config() + config.random_seed = 42 + # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below) + config.restore_path = config_dict.placeholder(str) + config.experiment_kwargs = config_dict.ConfigDict( + dict( + config=dict( + debug=debug, + predictions_dir=config_dict.placeholder(str), + # 5 for model selection and early stopping, 50 for final eval. 
+ num_eval_iterations_to_ensemble=5, + dataset_kwargs=dict( + data_root='/data/', + online_subsampling_kwargs=dict( + max_nb_neighbours_per_type=[ + [[40, 20, 0, 40], [0, 0, 0, 0], [0, 0, 0, 0]], + [[40, 20, 0, 40], [40, 0, 10, 0], [0, 0, 0, 0]], + ], + remove_future_nodes=True, + deduplicate_nodes=True, + ), + ratio_unlabeled_data_to_labeled_data=10.0, + k_fold_split_id=config_dict.placeholder(int), + use_all_labels_when_not_training=False, + use_dummy_adjacencies=debug, + ), + optimizer=dict( + name='adamw', + kwargs=dict(weight_decay=1e-5, b1=0.9, b2=0.999), + learning_rate_schedule=dict( + use_schedule=True, + base_learning_rate=1e-2, + warmup_steps=50000, + total_steps=config.get_ref('training_steps'), + ), + ), + model_config=dict( + mlp_hidden_sizes=[32] if debug else [512], + latent_size=32 if debug else 256, + num_message_passing_steps=2 if debug else 4, + activation='relu', + dropout_rate=0.3, + dropedge_rate=0.25, + disable_edge_updates=True, + use_sent_edges=True, + normalization_type='layer_norm', + aggregation_function='sum', + ), + training=dict( + loss_config=dict( + bgrl_loss_config=dict( + stop_gradient_for_supervised_loss=False, + bgrl_loss_scale=1.0, + symmetrize=True, + first_graph_corruption_config=dict( + feature_drop_prob=0.4, + edge_drop_prob=0.2, + ), + second_graph_corruption_config=dict( + feature_drop_prob=0.4, + edge_drop_prob=0.2, + ), + ), + ), + # GPU memory may require reducing the `256`s below to `48`. + dynamic_batch_size_config=dict( + n_node=256 if debug else 340 * 256, + n_edge=512 if debug else 720 * 256, + n_graph=4 if debug else 256, + ), + ), + eval=dict( + split='valid', + ema_annealing_schedule=dict( + use_schedule=True, + base_rate=0.999, + total_steps=config.get_ref('training_steps')), + dynamic_batch_size_config=dict( + n_node=256 if debug else 340 * 128, + n_edge=512 if debug else 720 * 128, + n_graph=4 if debug else 128, + ), + )))) + + ## Training loop config. 
+ config.training_steps = 500000 + config.checkpoint_dir = '/tmp/checkpoint/mag/' + config.train_checkpoint_all_hosts = False + config.log_train_data_interval = 10 + config.log_tensors_interval = 10 + config.save_checkpoint_interval = 30 + config.best_model_eval_metric = 'accuracy' + config.best_model_eval_metric_higher_is_better = True + + return config diff --git a/ogb_lsc/mag/csr_builder.py b/ogb_lsc/mag/csr_builder.py new file mode 100644 index 0000000..ddc858f --- /dev/null +++ b/ogb_lsc/mag/csr_builder.py @@ -0,0 +1,136 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Builds CSR matrices which store the MAG graphs.""" + +import pathlib + +from absl import app +from absl import flags +from absl import logging +import numpy as np +import scipy.sparse + +# pylint: disable=g-bad-import-order +import data_utils + +Path = pathlib.Path + + +FLAGS = flags.FLAGS + +_DATA_FILES_AND_PARAMETERS = { + 'author_affiliated_with_institution_edges.npy': { + 'content_names': ('author', 'institution'), + 'use_boolean': False + }, + 'author_writes_paper_edges.npy': { + 'content_names': ('author', 'paper'), + 'use_boolean': False + }, + 'paper_cites_paper_edges.npy': { + 'content_names': ('paper', 'paper'), + 'use_boolean': True + }, +} + +flags.DEFINE_string('data_root', None, 'Data root directory') +flags.DEFINE_boolean('skip_existing', True, 'Skips existing CSR files') + +flags.mark_flags_as_required(['data_root']) + + +def _read_edge_data(path): + try: + return np.load(path, mmap_mode='r') + except FileNotFoundError: + # If the file path can't be found by np.load, use the file handle w/o mmap. 
+ with path.open('rb') as fid: + return np.load(fid) + + +def _build_coo(edges_data, use_boolean=False): + if use_boolean: + mat_coo = scipy.sparse.coo_matrix( + (np.ones_like(edges_data[1, :], + dtype=bool), (edges_data[0, :], edges_data[1, :]))) + else: + mat_coo = scipy.sparse.coo_matrix( + (edges_data[1, :], (edges_data[0, :], edges_data[1, :]))) + return mat_coo + + +def _get_output_paths(directory, content_names, use_boolean): + boolean_str = '_b' if use_boolean else '' + transpose_str = '_t' if len(set(content_names)) == 1 else '' + output_prefix = '_'.join(content_names) + output_prefix_t = '_'.join(content_names[::-1]) + output_filename = f'{output_prefix}{boolean_str}.npz' + output_filename_t = f'{output_prefix_t}{boolean_str}{transpose_str}.npz' + output_path = directory / output_filename + output_path_t = directory / output_filename_t + return output_path, output_path_t + + +def _write_csr(path, csr): + path.parent.mkdir(parents=True, exist_ok=True) + with path.open('wb') as fid: + scipy.sparse.save_npz(fid, csr) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + raw_data_dir = Path(FLAGS.data_root) / data_utils.RAW_DIR + preprocessed_dir = Path(FLAGS.data_root) / data_utils.PREPROCESSED_DIR + + for input_filename, parameters in _DATA_FILES_AND_PARAMETERS.items(): + input_path = raw_data_dir / input_filename + output_path, output_path_t = _get_output_paths(preprocessed_dir, + **parameters) + if FLAGS.skip_existing and output_path.exists() and output_path_t.exists(): + # If both files exist, skip. When only one exists, that's handled below. + logging.info( + '%s and %s exist: skipping. 
Use flag `--skip_existing=False`' + ' to force overwrite existing.', output_path, output_path_t) + continue + logging.info('Reading edge data from: %s', input_path) + edge_data = _read_edge_data(input_path) + logging.info('Building CSR matrices') + mat_coo = _build_coo(edge_data, use_boolean=parameters['use_boolean']) + # Convert matrices to CSR and write to disk. + if not FLAGS.skip_existing or not output_path.exists(): + logging.info('Writing CSR matrix to: %s', output_path) + mat_csr = mat_coo.tocsr() + _write_csr(output_path, mat_csr) + del mat_csr # Free up memory asap. + else: + logging.info( + '%s exists: skipping. Use flag `--skip_existing=False`' + ' to force overwrite existing.', output_path) + if not FLAGS.skip_existing or not output_path_t.exists(): + logging.info('Writing (transposed) CSR matrix to: %s', output_path_t) + mat_csr_t = mat_coo.transpose().tocsr() + _write_csr(output_path_t, mat_csr_t) + del mat_csr_t # Free up memory asap. + else: + logging.info( + '%s exists: skipping. Use flag `--skip_existing=False`' + ' to force overwrite existing.', output_path_t) + del mat_coo # Free up memory asap. + + +if __name__ == '__main__': + app.run(main) diff --git a/ogb_lsc/mag/data_utils.py b/ogb_lsc/mag/data_utils.py new file mode 100644 index 0000000..639f317 --- /dev/null +++ b/ogb_lsc/mag/data_utils.py @@ -0,0 +1,491 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Dataset utilities.""" + +import functools +import pathlib +from typing import Dict, Tuple + +from absl import logging +from graph_nets import graphs as tf_graphs +from graph_nets import utils_tf +import numpy as np +import scipy.sparse as sp +import tensorflow as tf +import tqdm + +# pylint: disable=g-bad-import-order +import sub_sampler + +Path = pathlib.Path + +NUM_PAPERS = 121751666 +NUM_AUTHORS = 122383112 +NUM_INSTITUTIONS = 25721 +EMBEDDING_SIZE = 768 +NUM_CLASSES = 153 + +NUM_NODES = NUM_PAPERS + NUM_AUTHORS + NUM_INSTITUTIONS +NUM_EDGES = 1_728_364_232 +assert NUM_NODES == 244_160_499 + +NUM_K_FOLD_SPLITS = 10 + + +OFFSETS = { + "paper": 0, + "author": NUM_PAPERS, + "institution": NUM_PAPERS + NUM_AUTHORS, +} + + +SIZES = { + "paper": NUM_PAPERS, + "author": NUM_AUTHORS, + "institution": NUM_INSTITUTIONS +} + +RAW_DIR = Path("raw") +PREPROCESSED_DIR = Path("preprocessed") + +RAW_NODE_FEATURES_FILENAME = RAW_DIR / "node_feat.npy" +RAW_NODE_LABELS_FILENAME = RAW_DIR / "node_label.npy" +RAW_NODE_YEAR_FILENAME = RAW_DIR / "node_year.npy" + +TRAIN_INDEX_FILENAME = RAW_DIR / "train_idx.npy" +VALID_INDEX_FILENAME = RAW_DIR / "valid_idx.npy" +TEST_INDEX_FILENAME = RAW_DIR / "test_idx.npy" + +EDGES_PAPER_PAPER_B = PREPROCESSED_DIR / "paper_paper_b.npz" +EDGES_PAPER_PAPER_B_T = PREPROCESSED_DIR / "paper_paper_b_t.npz" +EDGES_AUTHOR_INSTITUTION = PREPROCESSED_DIR / "author_institution.npz" +EDGES_INSTITUTION_AUTHOR = PREPROCESSED_DIR / "institution_author.npz" +EDGES_AUTHOR_PAPER = PREPROCESSED_DIR / "author_paper.npz" +EDGES_PAPER_AUTHOR = PREPROCESSED_DIR / "paper_author.npz" + +PCA_PAPER_FEATURES_FILENAME = PREPROCESSED_DIR / "paper_feat_pca_129.npy" +PCA_AUTHOR_FEATURES_FILENAME = ( + PREPROCESSED_DIR / "author_feat_from_paper_feat_pca_129.npy") +PCA_INSTITUTION_FEATURES_FILENAME = ( + PREPROCESSED_DIR / "institution_feat_from_paper_feat_pca_129.npy") +PCA_MERGED_FEATURES_FILENAME = ( + PREPROCESSED_DIR / "merged_feat_from_paper_feat_pca_129.npy") +
+NEIGHBOR_INDICES_FILENAME = PREPROCESSED_DIR / "neighbor_indices.npy" +NEIGHBOR_DISTANCES_FILENAME = PREPROCESSED_DIR / "neighbor_distances.npy" + +FUSED_NODE_LABELS_FILENAME = PREPROCESSED_DIR / "fused_node_labels.npy" +FUSED_PAPER_EDGES_FILENAME = PREPROCESSED_DIR / "fused_paper_edges.npz" +FUSED_PAPER_EDGES_T_FILENAME = PREPROCESSED_DIR / "fused_paper_edges_t.npz" + +K_FOLD_SPLITS_DIR = Path("k_fold_splits") + + +def get_raw_directory(data_root): + return Path(data_root) / "raw" + + +def get_preprocessed_directory(data_root): + return Path(data_root) / "preprocessed" + + +def _log_path_decorator(fn): + def _decorated_fn(path, **kwargs): + logging.info("Loading %s", path) + output = fn(path, **kwargs) + logging.info("Finish loading %s", path) + return output + return _decorated_fn + + +@_log_path_decorator +def load_csr(path, debug=False): + if debug: + # Dummy matrix for debugging. + return sp.csr_matrix(np.zeros([10, 10])) + return sp.load_npz(str(path)) + + +@_log_path_decorator +def load_npy(path): + return np.load(str(path)) + + +@functools.lru_cache() +def get_arrays(data_root="/data/", + use_fused_node_labels=True, + use_fused_node_adjacencies=True, + return_pca_embeddings=True, + k_fold_split_id=None, + return_adjacencies=True, + use_dummy_adjacencies=False): + """Returns all arrays needed for training.""" + logging.info("Starting to get files") + + data_root = Path(data_root) + + array_dict = {} + array_dict["paper_year"] = load_npy(data_root / RAW_NODE_YEAR_FILENAME) + + if k_fold_split_id is None: + train_indices = load_npy(data_root / TRAIN_INDEX_FILENAME) + valid_indices = load_npy(data_root / VALID_INDEX_FILENAME) + else: + train_indices, valid_indices = get_train_and_valid_idx_for_split( + k_fold_split_id, num_splits=NUM_K_FOLD_SPLITS, + root_path=data_root / K_FOLD_SPLITS_DIR) + + array_dict["train_indices"] = train_indices + array_dict["valid_indices"] = valid_indices + array_dict["test_indices"] = load_npy(data_root / TEST_INDEX_FILENAME) + + 
if use_fused_node_labels: + array_dict["paper_label"] = load_npy(data_root / FUSED_NODE_LABELS_FILENAME) + else: + array_dict["paper_label"] = load_npy(data_root / RAW_NODE_LABELS_FILENAME) + + if return_adjacencies: + logging.info("Starting to get adjacencies.") + if use_fused_node_adjacencies: + paper_paper_index = load_csr( + data_root / FUSED_PAPER_EDGES_FILENAME, debug=use_dummy_adjacencies) + paper_paper_index_t = load_csr( + data_root / FUSED_PAPER_EDGES_T_FILENAME, debug=use_dummy_adjacencies) + else: + paper_paper_index = load_csr( + data_root / EDGES_PAPER_PAPER_B, debug=use_dummy_adjacencies) + paper_paper_index_t = load_csr( + data_root / EDGES_PAPER_PAPER_B_T, debug=use_dummy_adjacencies) + array_dict.update( + dict( + author_institution_index=load_csr( + data_root / EDGES_AUTHOR_INSTITUTION, + debug=use_dummy_adjacencies), + institution_author_index=load_csr( + data_root / EDGES_INSTITUTION_AUTHOR, + debug=use_dummy_adjacencies), + author_paper_index=load_csr( + data_root / EDGES_AUTHOR_PAPER, debug=use_dummy_adjacencies), + paper_author_index=load_csr( + data_root / EDGES_PAPER_AUTHOR, debug=use_dummy_adjacencies), + paper_paper_index=paper_paper_index, + paper_paper_index_t=paper_paper_index_t, + )) + + if return_pca_embeddings: + array_dict["bert_pca_129"] = np.load( + data_root / PCA_MERGED_FEATURES_FILENAME, mmap_mode="r") + assert array_dict["bert_pca_129"].shape == (NUM_NODES, 129) + + logging.info("Finish getting files") + + # pytype: disable=attribute-error + assert array_dict["paper_year"].shape[0] == NUM_PAPERS + assert array_dict["paper_label"].shape[0] == NUM_PAPERS + + if return_adjacencies and not use_dummy_adjacencies: + array_dict = _fix_adjacency_shapes(array_dict) + + assert array_dict["paper_author_index"].shape == (NUM_PAPERS, NUM_AUTHORS) + assert array_dict["author_paper_index"].shape == (NUM_AUTHORS, NUM_PAPERS) + assert array_dict["paper_paper_index"].shape == (NUM_PAPERS, NUM_PAPERS) + assert 
array_dict["paper_paper_index_t"].shape == (NUM_PAPERS, NUM_PAPERS) + assert array_dict["institution_author_index"].shape == ( + NUM_INSTITUTIONS, NUM_AUTHORS) + assert array_dict["author_institution_index"].shape == ( + NUM_AUTHORS, NUM_INSTITUTIONS) + + # pytype: enable=attribute-error + + return array_dict + + +def add_nodes_year(graph, paper_year): + nodes = graph.nodes.copy() + indices = nodes["index"] + year = paper_year[np.minimum(indices, paper_year.shape[0] - 1)].copy() + year[nodes["type"] != 0] = 1900 + nodes["year"] = year + return graph._replace(nodes=nodes) + + +def add_nodes_label(graph, paper_label): + nodes = graph.nodes.copy() + indices = nodes["index"] + label = paper_label[np.minimum(indices, paper_label.shape[0] - 1)] + label[nodes["type"] != 0] = 0 + nodes["label"] = label + return graph._replace(nodes=nodes) + + +def add_nodes_embedding_from_array(graph, array): + """Adds node embeddings looked up in `array` by offset node index.""" + nodes = graph.nodes.copy() + indices = nodes["index"] + embedding_indices = indices.copy() + embedding_indices[nodes["type"] == 1] += NUM_PAPERS + embedding_indices[nodes["type"] == 2] += NUM_PAPERS + NUM_AUTHORS + + # Gather the embeddings for the indices.
+ nodes["features"] = array[embedding_indices] + return graph._replace(nodes=nodes) + + +def get_graph_subsampling_dataset( + prefix, arrays, shuffle_indices, ratio_unlabeled_data_to_labeled_data, + max_nodes, max_edges, + **subsampler_kwargs): + """Returns tf_dataset for online sampling.""" + + def generator(): + labeled_indices = arrays[f"{prefix}_indices"] + if ratio_unlabeled_data_to_labeled_data > 0: + num_unlabeled_data_to_add = int(ratio_unlabeled_data_to_labeled_data * + labeled_indices.shape[0]) + unlabeled_indices = np.random.choice( + NUM_PAPERS, size=num_unlabeled_data_to_add, replace=False) + root_node_indices = np.concatenate([labeled_indices, unlabeled_indices]) + else: + root_node_indices = labeled_indices + if shuffle_indices: + root_node_indices = root_node_indices.copy() + np.random.shuffle(root_node_indices) + + for index in root_node_indices: + graph = sub_sampler.subsample_graph( + index, + arrays["author_institution_index"], + arrays["institution_author_index"], + arrays["author_paper_index"], + arrays["paper_author_index"], + arrays["paper_paper_index"], + arrays["paper_paper_index_t"], + paper_years=arrays["paper_year"], + max_nodes=max_nodes, + max_edges=max_edges, + **subsampler_kwargs) + + graph = add_nodes_label(graph, arrays["paper_label"]) + graph = add_nodes_year(graph, arrays["paper_year"]) + graph = tf_graphs.GraphsTuple(*graph) + yield graph + + sample_graph = next(generator()) + + return tf.data.Dataset.from_generator( + generator, + output_signature=utils_tf.specs_from_graphs_tuple(sample_graph)) + + +def paper_features_to_author_features( + author_paper_index, paper_features): + """Averages paper features to authors.""" + assert paper_features.shape[0] == NUM_PAPERS + assert author_paper_index.shape[0] == NUM_AUTHORS + author_features = np.zeros( + [NUM_AUTHORS, paper_features.shape[1]], dtype=paper_features.dtype) + for author_i in range(NUM_AUTHORS): + paper_indices = author_paper_index[author_i].indices + 
author_features[author_i] = paper_features[paper_indices].mean( + axis=0, dtype=np.float32) + if author_i % 10000 == 0: + logging.info("%d/%d", author_i, NUM_AUTHORS) + return author_features + + +def author_features_to_institution_features( + institution_author_index, author_features): + """Averages author features to institutions.""" + assert author_features.shape[0] == NUM_AUTHORS + assert institution_author_index.shape[0] == NUM_INSTITUTIONS + institution_features = np.zeros( + [NUM_INSTITUTIONS, author_features.shape[1]], dtype=author_features.dtype) + for institution_i in range(NUM_INSTITUTIONS): + author_indices = institution_author_index[institution_i].indices + institution_features[institution_i] = author_features[ + author_indices].mean(axis=0, dtype=np.float32) + if institution_i % 10000 == 0: + logging.info("%d/%d", institution_i, NUM_INSTITUTIONS) + return institution_features + + +def generate_fused_paper_adjacency_matrix(neighbor_indices, neighbor_distances, + paper_paper_csr): + """Generates fused adjacency matrix for identical nodes.""" + # First construct set of identical node indices. + # NOTE: Since we take only top K=26 identical pairs for each node, this is not + # actually exhaustive. Also, if A and B are equal, and B and C are equal, + # this method would not necessarily detect A and C being equal. + # However, this should capture almost all cases. + logging.info("Generating fused paper adjacency matrix") + eps = 0.0 + mask = ((neighbor_indices != np.mgrid[:neighbor_indices.shape[0], :1]) & + (neighbor_distances <= eps)) + identical_pairs = list(map(tuple, np.nonzero(mask))) + del mask + + # Have a csc version for fast column access. + paper_paper_csc = paper_paper_csr.tocsc() + + # Construct new matrix as coo, starting off with original rows/cols. 
+ paper_paper_coo = paper_paper_csr.tocoo() + new_rows = [paper_paper_coo.row] + new_cols = [paper_paper_coo.col] + + for pair in tqdm.tqdm(identical_pairs): + # STEP ONE: First merge papers being cited by the pair. + # Add edges from second paper, to all papers cited by first paper. + cited_by_first = paper_paper_csr.getrow(pair[0]).nonzero()[1] + if cited_by_first.shape[0] > 0: + new_rows.append(pair[1] * np.ones_like(cited_by_first)) + new_cols.append(cited_by_first) + + # Add edges from first paper, to all papers cited by second paper. + cited_by_second = paper_paper_csr.getrow(pair[1]).nonzero()[1] + if cited_by_second.shape[0] > 0: + new_rows.append(pair[0] * np.ones_like(cited_by_second)) + new_cols.append(cited_by_second) + + # STEP TWO: Then merge papers that cite the pair. + # Add edges to second paper, from all papers citing the first paper. + citing_first = paper_paper_csc.getcol(pair[0]).nonzero()[0] + if citing_first.shape[0] > 0: + new_rows.append(citing_first) + new_cols.append(pair[1] * np.ones_like(citing_first)) + + # Add edges to first paper, from all papers citing the second paper. + citing_second = paper_paper_csc.getcol(pair[1]).nonzero()[0] + if citing_second.shape[0] > 0: + new_rows.append(citing_second) + new_cols.append(pair[0] * np.ones_like(citing_second)) + + logging.info("Done with adjacency loop") + paper_paper_coo_shape = paper_paper_coo.shape + del paper_paper_csr + del paper_paper_csc + del paper_paper_coo + # All done; now concatenate everything together and form new matrix. 
+ new_rows = np.concatenate(new_rows) + new_cols = np.concatenate(new_cols) + return sp.coo_matrix( + (np.ones_like(new_rows, dtype=bool), (new_rows, new_cols)), + shape=paper_paper_coo_shape).tocsr() + + +def generate_k_fold_splits( + train_idx, valid_idx, output_path, num_splits=NUM_K_FOLD_SPLITS): + """Generates splits adding fractions of the validation split to training.""" + output_path = Path(output_path) + np.random.seed(42) + valid_idx = np.random.permutation(valid_idx) + # Split into `num_splits` (almost) identically sized arrays. + valid_idx_parts = np.array_split(valid_idx, num_splits) + + for i in range(num_splits): + # Add all but the i'th subpart to training set. + new_train_idx = np.concatenate( + [train_idx, *valid_idx_parts[:i], *valid_idx_parts[i+1:]]) + # i'th subpart is validation set. + new_valid_idx = valid_idx_parts[i] + train_path = output_path / f"train_idx_{i}_{num_splits}.npy" + valid_path = output_path / f"valid_idx_{i}_{num_splits}.npy" + np.save(train_path, new_train_idx) + np.save(valid_path, new_valid_idx) + logging.info("Saved: %s", train_path) + logging.info("Saved: %s", valid_path) + + +def get_train_and_valid_idx_for_split( + split_id: int, + num_splits: int, + root_path: str, +) -> Tuple[np.ndarray, np.ndarray]: + """Returns train and valid indices for given split.""" + new_train_idx = load_npy(f"{root_path}/train_idx_{split_id}_{num_splits}.npy") + new_valid_idx = load_npy(f"{root_path}/valid_idx_{split_id}_{num_splits}.npy") + return new_train_idx, new_valid_idx + + +def generate_fused_node_labels(neighbor_indices, neighbor_distances, + node_labels, train_indices, valid_indices, + test_indices): + """Generates fused node labels for identical nodes.""" + logging.info("Generating fused node labels") + valid_indices = set(valid_indices.tolist()) + test_indices = set(test_indices.tolist()) + valid_or_test_indices = valid_indices | test_indices + + train_indices = train_indices[train_indices < neighbor_indices.shape[0]] + #
Go through list of all pairs where one node is in training set, and + for i in tqdm.tqdm(train_indices): + for j in range(neighbor_indices.shape[1]): + other_index = neighbor_indices[i][j] + # if the other is not a validation or test node, + if other_index in valid_or_test_indices: + continue + # and they are identical, + if neighbor_distances[i][j] == 0: + # assign the label of the training node to the other node + node_labels[other_index] = node_labels[i] + + return node_labels + + +def _pad_to_shape( + sparse_csr_matrix: sp.csr_matrix, + output_shape: Tuple[int, int]) -> sp.csr_matrix: + """Pads a csr sparse matrix to the given shape.""" + + # The output shape must be at least as large as the input shape in every + # dimension. Compare as arrays, since comparing tuples is lexicographic. + assert np.all( + np.asarray(sparse_csr_matrix.shape) <= np.asarray(output_shape)) + + # Maybe it already has the right shape. + if sparse_csr_matrix.shape == output_shape: + return sparse_csr_matrix + + # Append as many indptr elements as we need to match the leading size. + # This is achieved by just padding with copies of the last indptr element. + required_padding = output_shape[0] - sparse_csr_matrix.shape[0] + updated_indptr = np.concatenate( + [sparse_csr_matrix.indptr] + + [sparse_csr_matrix.indptr[-1:]] * required_padding, + axis=0) + + # The change in trailing size does not have structural implications, it just + # determines the highest possible value for the indices, so it is sufficient + # to just pass the new output shape, with the correct trailing size.
+ return sp.csr.csr_matrix( + (sparse_csr_matrix.data, + sparse_csr_matrix.indices, + updated_indptr), + shape=output_shape) + + +def _fix_adjacency_shapes( + arrays: Dict[str, sp.csr.csr_matrix], + ) -> Dict[str, sp.csr.csr_matrix]: + """Fixes the shapes of the adjacency matrices.""" + arrays = arrays.copy() + for key in ["author_institution_index", + "author_paper_index", + "paper_paper_index", + "institution_author_index", + "paper_author_index", + "paper_paper_index_t"]: + type_sender = key.split("_")[0] + type_receiver = key.split("_")[1] + arrays[key] = _pad_to_shape( + arrays[key], output_shape=(SIZES[type_sender], SIZES[type_receiver])) + return arrays diff --git a/ogb_lsc/mag/datasets.py b/ogb_lsc/mag/datasets.py new file mode 100644 index 0000000..dffd6cb --- /dev/null +++ b/ogb_lsc/mag/datasets.py @@ -0,0 +1,297 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""MAG240M-LSC datasets.""" + +import threading +from typing import NamedTuple, Optional + + +import jax +import jraph +from ml_collections import config_dict +import numpy as np +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds + +# pylint: disable=g-bad-import-order +# pytype: disable=import-error +import batching_utils +import data_utils + + +# We only want to load these arrays once for all threads. +# `get_arrays` uses an LRU cache which is not thread safe. 
+LOADING_RAW_ARRAYS_LOCK = threading.Lock() + +NUM_CLASSES = data_utils.NUM_CLASSES + + +_MAX_DEPTH_IN_SUBGRAPH = 3 + + +class Batch(NamedTuple): + """NamedTuple to represent batches of data.""" + graph: jraph.GraphsTuple + node_labels: np.ndarray + label_mask: np.ndarray + central_node_mask: np.ndarray + node_indices: np.ndarray + absolute_node_indices: np.ndarray + + +def build_dataset_iterator( + data_root: str, + split: str, + dynamic_batch_size_config: config_dict.ConfigDict, + online_subsampling_kwargs: dict, # pylint: disable=g-bare-generic + debug: bool = False, + is_training: bool = True, + k_fold_split_id: Optional[int] = None, + ratio_unlabeled_data_to_labeled_data: float = 0.0, + use_all_labels_when_not_training: bool = False, + use_dummy_adjacencies: bool = False, +): + """Returns an iterator over Batches from the dataset.""" + + if split == 'test': + use_all_labels_when_not_training = True + + if not is_training: + ratio_unlabeled_data_to_labeled_data = 0.0 + + # Load the master data arrays. + with LOADING_RAW_ARRAYS_LOCK: + array_dict = data_utils.get_arrays( + data_root, k_fold_split_id=k_fold_split_id, + use_dummy_adjacencies=use_dummy_adjacencies) + + node_labels = array_dict['paper_label'].reshape(-1) + train_indices = array_dict['train_indices'].astype(np.int32) + is_train_index = np.zeros(node_labels.shape[0], dtype=np.int32) + is_train_index[train_indices] = 1 + valid_indices = array_dict['valid_indices'].astype(np.int32) + is_valid_index = np.zeros(node_labels.shape[0], dtype=np.int32) + is_valid_index[valid_indices] = 1 + is_train_or_valid_index = is_train_index + is_valid_index + + def sstable_to_intermediate_graph(graph): + indices = tf.cast(graph.nodes['index'], tf.int32) + first_index = indices[..., 0] + + # Add an additional absolute index by adding offsets to the author and + # institution indices.
+ absolute_index = graph.nodes['index'] + is_author = graph.nodes['type'] == 1 + absolute_index = tf.where( + is_author, absolute_index + data_utils.NUM_PAPERS, absolute_index) + is_institution = graph.nodes['type'] == 2 + absolute_index = tf.where( + is_institution, + absolute_index + data_utils.NUM_PAPERS + data_utils.NUM_AUTHORS, + absolute_index) + + is_same_as_central_node = tf.math.equal(indices, first_index) + input_nodes = graph.nodes + graph = graph._replace( + nodes={ + 'one_hot_type': + tf.one_hot(tf.cast(input_nodes['type'], tf.int32), 3), + 'one_hot_depth': + tf.one_hot( + tf.cast(input_nodes['depth'], tf.int32), + _MAX_DEPTH_IN_SUBGRAPH), + 'year': + tf.expand_dims(input_nodes['year'], axis=-1), + 'label': + tf.one_hot( + tf.cast(input_nodes['label'], tf.int32), + NUM_CLASSES), + 'is_same_as_central_node': + is_same_as_central_node, + # Only first node in graph has a valid label. + 'is_central_node': + tf.one_hot(0, + tf.shape(input_nodes['label'])[0]), + 'index': + input_nodes['index'], + 'absolute_index': absolute_index, + }, + globals=tf.expand_dims(graph.globals, axis=-1), + ) + + return graph + + ds = data_utils.get_graph_subsampling_dataset( + split, + array_dict, + shuffle_indices=is_training, + ratio_unlabeled_data_to_labeled_data=ratio_unlabeled_data_to_labeled_data, + max_nodes=dynamic_batch_size_config.n_node - 1, # Keep space for pads. 
+ max_edges=dynamic_batch_size_config.n_edge, + **online_subsampling_kwargs) + if debug: + ds = ds.take(50) + ds = ds.map( + sstable_to_intermediate_graph, + num_parallel_calls=tf.data.experimental.AUTOTUNE) + + if is_training: + ds = ds.shard(jax.process_count(), jax.process_index()) + ds = ds.shuffle(buffer_size=1 if debug else 128) + ds = ds.repeat() + ds = ds.prefetch(1 if debug else tf.data.experimental.AUTOTUNE) + np_ds = iter(tfds.as_numpy(ds)) + batched_np_ds = batching_utils.dynamically_batch( + np_ds, + **dynamic_batch_size_config, + ) + + def intermediate_graph_to_batch(graph): + central_node_mask = graph.nodes['is_central_node'] + label = graph.nodes['label'] + node_indices = graph.nodes['index'] + absolute_indices = graph.nodes['absolute_index'] + + ### Construct label as a feature for non-central nodes. + # First do a lookup with node indices, with a np.minimum to ensure we do not + # index out of bounds due to num_authors being larger than num_papers. + is_same_as_central_node = graph.nodes['is_same_as_central_node'] + capped_indices = np.minimum(node_indices, node_labels.shape[0] - 1) + label_as_feature = node_labels[capped_indices] + # Nodes which are not in train set should get `num_classes` label. + # Nodes in test set or non-arXiv nodes have -1 or nan labels. + + # Mask out invalid labels and non-papers. + use_label_as_feature = np.logical_and(label_as_feature >= 0, + graph.nodes['one_hot_type'][..., 0]) + if split == 'train' or not use_all_labels_when_not_training: + # Mask out validation papers and non-arxiv papers who + # got labels from fusing with arxiv papers. + use_label_as_feature = np.logical_and(is_train_index[capped_indices], + use_label_as_feature) + label_as_feature = np.where(use_label_as_feature, label_as_feature, + NUM_CLASSES) + # Mask out central node label in case it appears again. + label_as_feature = np.where(is_same_as_central_node, NUM_CLASSES, + label_as_feature) + # Nodes which are not papers get `NUM_CLASSES+1` label. 
+ label_as_feature = np.where(graph.nodes['one_hot_type'][..., 0], + label_as_feature, NUM_CLASSES+1) + + nodes = { + 'label_as_feature': label_as_feature, + 'year': graph.nodes['year'], + 'bitstring_year': _get_bitstring_year_representation( + graph.nodes['year']), + 'one_hot_type': graph.nodes['one_hot_type'], + 'one_hot_depth': graph.nodes['one_hot_depth'], + } + + graph = graph._replace( + nodes=nodes, + globals={}, + ) + is_train_or_valid_node = np.logical_and( + is_train_or_valid_index[capped_indices], + graph.nodes['one_hot_type'][..., 0]) + if is_training: + label_mask = np.logical_and(central_node_mask, is_train_or_valid_node) + else: + # `label_mask` is used to index into valid central nodes by prediction + # calculator. Since that computation is only done when not training, and + # at that time we are guaranteed all central nodes have valid labels, + # we just set label_mask = central_node_mask when not training. + label_mask = central_node_mask + batch = Batch( + graph=graph, + node_labels=label, + central_node_mask=central_node_mask, + label_mask=label_mask, + node_indices=node_indices, + absolute_node_indices=absolute_indices) + + # Transform integers into one-hots. + batch = _add_one_hot_features_to_batch(batch) + + # Gather PCA features. 
+ return _add_embeddings_to_batch(batch, array_dict['bert_pca_129']) + + batch_list = [] + for batch in batched_np_ds: + with jax.profiler.StepTraceAnnotation('batch_postprocessing'): + batch = intermediate_graph_to_batch(batch) + if is_training: + batch_list.append(batch) + if len(batch_list) == jax.local_device_count(): + yield jax.device_put_sharded(batch_list, jax.local_devices()) + batch_list = [] + else: + yield batch + + +def _get_bitstring_year_representation(year: np.ndarray): + """Return year as bitstring.""" + min_year = 1900 + max_training_year = 2018 + offseted_year = np.minimum(year, max_training_year) - min_year + return np.unpackbits(offseted_year.astype(np.uint8), axis=-1) + + +def _np_one_hot(targets: np.ndarray, nb_classes: int): + res = np.zeros(targets.shape + (nb_classes,), dtype=np.float16) + np.put_along_axis(res, targets.astype(np.int32)[..., None], 1.0, axis=-1) + return res + + +def _get_one_hot_year_representation( + year: np.ndarray, + one_hot_type: np.ndarray, +): + """Returns good representation for year.""" + # Bucket edges found based on quantiles to bucket into 20 equal sized buckets. 
+ bucket_edges = np.array([ + 1964, 1975, 1983, 1989, 1994, 1998, 2001, 2004, + 2006, 2008, 2009, 2011, 2012, 2013, 2014, 2016, + 2017, # 2018, 2019, 2020 contain last-year-of-train, eval, test nodes + ]) + year = np.squeeze(year, axis=-1) + year_id = np.searchsorted(bucket_edges, year) + is_paper = one_hot_type[..., 0] + bucket_id_for_non_paper = len(bucket_edges) + 1 + bucket_id = np.where(is_paper, year_id, bucket_id_for_non_paper) + one_hot_year = _np_one_hot(bucket_id, len(bucket_edges) + 2) + return one_hot_year + + +def _add_one_hot_features_to_batch(batch: Batch) -> Batch: + """Transforms integer features into one-hot features.""" + nodes = batch.graph.nodes.copy() + nodes['one_hot_year'] = _get_one_hot_year_representation( + nodes['year'], nodes['one_hot_type']) + del nodes['year'] + + # NUM_CLASSES plus one category for papers for which a class is not provided + # and another for nodes that are not papers. + nodes['one_hot_label_as_feature'] = _np_one_hot( + nodes['label_as_feature'], NUM_CLASSES + 2) + del nodes['label_as_feature'] + return batch._replace(graph=batch.graph._replace(nodes=nodes)) + + +def _add_embeddings_to_batch(batch: Batch, embeddings: np.ndarray) -> Batch: + nodes = batch.graph.nodes.copy() + nodes['features'] = embeddings[batch.absolute_node_indices] + graph = batch.graph._replace(nodes=nodes) + return batch._replace(graph=graph) diff --git a/ogb_lsc/mag/download_mag.py b/ogb_lsc/mag/download_mag.py new file mode 100644 index 0000000..cef7f2c --- /dev/null +++ b/ogb_lsc/mag/download_mag.py @@ -0,0 +1,110 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download data required for training and evaluating models.""" + +import pathlib + +from absl import app +from absl import flags +from absl import logging +from google.cloud import storage + +# pylint: disable=g-bad-import-order +import data_utils + +Path = pathlib.Path + + +_BUCKET_NAME = 'deepmind-ogb-lsc' +_MAX_DOWNLOAD_ATTEMPTS = 5 + +FLAGS = flags.FLAGS + +flags.DEFINE_enum('payload', None, ['data', 'models'], + 'Download "data" or "models"?') +flags.DEFINE_string('task_root', None, 'Local task root directory') + +DATA_RELATIVE_PATHS = ( + data_utils.RAW_NODE_YEAR_FILENAME, + data_utils.TRAIN_INDEX_FILENAME, + data_utils.VALID_INDEX_FILENAME, + data_utils.TEST_INDEX_FILENAME, + data_utils.K_FOLD_SPLITS_DIR, + data_utils.FUSED_NODE_LABELS_FILENAME, + data_utils.FUSED_PAPER_EDGES_FILENAME, + data_utils.FUSED_PAPER_EDGES_T_FILENAME, + data_utils.EDGES_AUTHOR_INSTITUTION, + data_utils.EDGES_INSTITUTION_AUTHOR, + data_utils.EDGES_AUTHOR_PAPER, + data_utils.EDGES_PAPER_AUTHOR, + data_utils.PCA_MERGED_FEATURES_FILENAME, + ) + + +class DataCorruptionError(Exception): + pass + + +def _get_gcs_root(): + return Path('mag') / FLAGS.payload + + +def _get_gcs_bucket(): + storage_client = storage.Client.create_anonymous_client() + return storage_client.bucket(_BUCKET_NAME) + + +def _write_blob_to_destination(blob, task_root, ignore_existing=True): + """Write the blob.""" + logging.info("Copying blob: '%s'", blob.name) + destination_path = Path(task_root) / Path(*Path(blob.name).parts[1:]) + logging.info(" ... 
to: '%s'", str(destination_path)) + if ignore_existing and destination_path.exists(): + return + destination_path.parent.mkdir(parents=True, exist_ok=True) + checksum = 'crc32c' + for attempt in range(_MAX_DOWNLOAD_ATTEMPTS): + try: + blob.download_to_filename(destination_path.as_posix(), checksum=checksum) + except storage.client.resumable_media.common.DataCorruption: + pass + else: + break + else: + raise DataCorruptionError(f"Checksum ('{checksum}') for {blob.name} failed " + f'after {attempt + 1} attempts') + + +def main(unused_argv): + bucket = _get_gcs_bucket() + if FLAGS.payload == 'data': + relative_paths = DATA_RELATIVE_PATHS + else: + relative_paths = (None,) + for relative_path in relative_paths: + if relative_path is None: + relative_path = str(_get_gcs_root()) + else: + relative_path = str(_get_gcs_root() / relative_path) + logging.info("Copying relative path: '%s'", relative_path) + blobs = bucket.list_blobs(prefix=relative_path) + for blob in blobs: + _write_blob_to_destination(blob, FLAGS.task_root) + + +if __name__ == '__main__': + flags.mark_flag_as_required('payload') + flags.mark_flag_as_required('task_root') + app.run(main) diff --git a/ogb_lsc/mag/ensemble_predictions.py b/ogb_lsc/mag/ensemble_predictions.py new file mode 100644 index 0000000..3b0e77d --- /dev/null +++ b/ogb_lsc/mag/ensemble_predictions.py @@ -0,0 +1,227 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Ensemble k-fold predictions and generate final submission file.""" + +import collections +import os + +from absl import app +from absl import flags +from absl import logging +import dill +import jax +import numpy as np +from ogb import lsc + +# pylint: disable=g-bad-import-order +import data_utils +import losses + + +_NUM_KFOLD_SPLITS = 10 + +FLAGS = flags.FLAGS + + +_DATA_ROOT = flags.DEFINE_string('data_root', None, 'Path to the data root') +_SPLIT = flags.DEFINE_enum('split', None, ['valid', 'test'], 'Data split') +_PREDICTIONS_PATH = flags.DEFINE_string( + 'predictions_path', None, 'Path with the output of the k-fold models.') +_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Output path.') + + +def _np_one_hot(targets: np.ndarray, nb_classes: int): + res = np.zeros(targets.shape + (nb_classes,), dtype=np.float32) + np.put_along_axis(res, targets.astype(np.int32)[..., None], 1.0, axis=-1) + return res + + +def ensemble_predictions( + node_idx_to_logits_list, + all_labels, + node_indices, + use_mode_break_tie_by_mean: bool = True, +): + """Ensemble together predictions for each node and generate final predictions.""" + # First, assert that each node has the same number of predictions to ensemble. + num_predictions_per_node = [ + len(x) for x in node_idx_to_logits_list.values() + ] + num_models = np.unique(num_predictions_per_node) + assert num_models.shape[0] == 1 + num_models = num_models[0] + # Gather all logits, shape should be [num_nodes, num_models, num_classes]. + all_logits = np.stack( + [np.stack(node_idx_to_logits_list[idx]) for idx in node_indices]) + assert all_logits.shape == (node_indices.shape[0], num_models, + data_utils.NUM_CLASSES) + # Softmax on the final axis. + all_probs = jax.nn.softmax(all_logits, axis=-1) + # Take average across models axis to get probabilities. + mean_probs = np.mean(all_probs, axis=1) + + # Assert there are no 2 equal logits for different classes. 
+ max_logit_value = np.max(all_logits, axis=-1) + num_classes_with_max_value = ( + all_logits == max_logit_value[..., None]).sum(axis=-1) + + num_logit_ties = (num_classes_with_max_value > 1).sum() + if num_logit_ties: + logging.warning( + 'Found %d models with the exact same logits for two of the classes. ' + '`argmax` will choose the first.', num_logit_ties) + + # Each model votes on one class per node. + all_votes = np.argmax(all_logits, axis=-1) + assert all_votes.shape == (node_indices.shape[0], num_models) + + all_votes_one_hot = _np_one_hot(all_votes, data_utils.NUM_CLASSES) + assert all_votes_one_hot.shape == (node_indices.shape[0], num_models, + data_utils.NUM_CLASSES) + + num_votes_per_class = np.sum(all_votes_one_hot, axis=1) + assert num_votes_per_class.shape == ( + node_indices.shape[0], data_utils.NUM_CLASSES) + + if use_mode_break_tie_by_mean: + # Slight hack, give high weight to votes (any number > 1 works really) + # and add probabilities between [0, 1] per class to tie-break only within + # classes with equal votes. + total_score = 10 * num_votes_per_class + mean_probs + else: + # Just take mean. + total_score = mean_probs + + ensembled_logits = np.log(total_score) + return losses.Predictions( + node_indices=node_indices, + labels=all_labels, + logits=ensembled_logits, + predictions=np.argmax(ensembled_logits, axis=-1), + ) + + +def load_predictions(predictions_path, split): + """Loads the predictions of each k-fold model for the given split.""" + + # Generate list of predictions per node. + # Note for validation each validation index is only present in exactly 1 + # model of the k-fold, however for test it is present in all of them. + node_idx_to_logits_list = collections.defaultdict(list) + + # For the 10 models in the ensemble. + for i in range(_NUM_KFOLD_SPLITS): + path = os.path.join(predictions_path, str(i)) + + # Find subdirectories.
+ # Directories will be something like: + # os.path.join(path, "step_104899_2021-06-14T18:20:05", "(test|valid).dill") + # So we make sure there is only one. + candidates = [] + for date_str in os.listdir(path): + candidate_path = os.path.join(path, date_str, f'{split}.dill') + if os.path.exists(candidate_path): + candidates.append(candidate_path) + if not candidates: + raise ValueError(f'No {split} predictions found at {path}') + elif len(candidates) > 1: + raise ValueError(f'Found more than one {split} predictions: {candidates}') + + path_for_kth_model_predictions = candidates[0] + with open(path_for_kth_model_predictions, 'rb') as f: + results = dill.load(f) + logging.info('Loaded %s', path_for_kth_model_predictions) + for (node_idx, logits) in zip(results.node_indices, + results.logits): + node_idx_to_logits_list[node_idx].append(logits) + + return node_idx_to_logits_list + + +def generate_ensembled_predictions( + data_root: str, predictions_path: str, split: str) -> losses.Predictions: + """Ensembles the predictions of all k-fold models for the given split.""" + + array_dict = data_utils.get_arrays( + data_root=data_root, + return_pca_embeddings=False, + return_adjacencies=False) + + # Load the predictions for the requested split. + node_idx_to_logits_list = load_predictions(predictions_path, split) + + # Assert that the indices loaded are as expected. + expected_idx = array_dict[f'{split}_indices'] + idx_found = np.array(list(node_idx_to_logits_list.keys())) + assert np.all(np.sort(idx_found) == expected_idx) + + if split == 'valid': + true_labels = array_dict['paper_label'][expected_idx.astype(np.int32)] + else: + # Don't know the test labels. + true_labels = np.full(expected_idx.shape, np.nan) + + # Ensemble together all predictions.
+ return ensemble_predictions( + node_idx_to_logits_list, true_labels, expected_idx) + + +def evaluate_validation(valid_predictions): + + evaluator = lsc.MAG240MEvaluator() + + evaluator_output = evaluator.eval( + dict(y_pred=valid_predictions.predictions.astype(np.float64), + y_true=valid_predictions.labels)) + logging.info( + 'Validation accuracy as reported by MAG240MEvaluator: %s', + evaluator_output) + + +def save_test_submission_file(test_predictions, output_dir): + evaluator = lsc.MAG240MEvaluator() + evaluator.save_test_submission( + dict(y_pred=test_predictions.predictions.astype(np.float64)), output_dir) + logging.info('Test submission file generated at %s', output_dir) + + +def main(argv): + del argv + + split = _SPLIT.value + + ensembled_predictions = generate_ensembled_predictions( + data_root=_DATA_ROOT.value, + predictions_path=_PREDICTIONS_PATH.value, + split=split) + + output_dir = _OUTPUT_PATH.value + os.makedirs(output_dir, exist_ok=True) + + if split == 'valid': + evaluate_validation(ensembled_predictions) + elif split == 'test': + save_test_submission_file(ensembled_predictions, output_dir) + + ensembled_predictions_path = os.path.join(output_dir, f'{split}.dill') + assert not os.path.exists(ensembled_predictions_path) + with open(ensembled_predictions_path, 'wb') as f: + dill.dump(ensembled_predictions, f) + logging.info( + '%s predictions stored at %s', split, ensembled_predictions_path) + + +if __name__ == '__main__': + app.run(main) diff --git a/ogb_lsc/mag/experiment.py b/ogb_lsc/mag/experiment.py new file mode 100644 index 0000000..1dfcbfd --- /dev/null +++ b/ogb_lsc/mag/experiment.py @@ -0,0 +1,725 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long + +r"""MAG240M-LSC Jaxline experiment. + +Usage: + + +``` +# A path pointing to the data root. +DATA_ROOT=/tmp/mag/data +# A path for checkpoints. +CHECKPOINT_DIR=/tmp/checkpoint/ +# A path for output predictions. +OUTPUT_DIR=/tmp/predictions/ +# Index of the k-fold split to train on (None for no k-fold). +K_FOLD_INDEX=0 + +``` + + +Some reusable arguments: + +``` +SHARED_ARGUMENTS="--config=ogb_lsc/mag/config.py \ + --config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \ + --config.checkpoint_dir=${CHECKPOINT_DIR}" +``` + +Train only: + ``` + python -m ogb_lsc.mag.experiment \ + ${SHARED_ARGUMENTS} --jaxline_mode="train" + RESTORE_PATH=${CHECKPOINT_DIR}/models/latest/step_${STEP}_${TIMESTAMP} + ``` + +Train with early stopping on a separate eval thread: + ``` + python -m ogb_lsc.mag.experiment \ + ${SHARED_ARGUMENTS} --jaxline_mode="train_eval_multithreaded" + RESTORE_PATH=${CHECKPOINT_DIR}/models/best/step_${STEP}_${TIMESTAMP} + ``` + +Produce predictions with a pretrained model: + ``` + SPLIT="valid" # Or "test" + EPOCHS_TO_ENSEMBLE=50 # We used this in the submission.
+ python -m ogb_lsc.mag.experiment \ + ${SHARED_ARGUMENTS} --jaxline_mode="eval" \ + --config.one_off_evaluate=True \ + --config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \ + --config.restore_path=${RESTORE_PATH} \ + --config.experiment_kwargs.config.predictions_dir=${OUTPUT_DIR} \ + --config.experiment_kwargs.config.eval.split=${SPLIT} + ``` + +Note that it is also possible to pass a `restore_path` with +`--jaxline_mode="train"`, and training will continue where it left off. In the +case of `--jaxline_mode="train_eval_multithreaded"` this will also work, but +early stopping will not take into account any past best performance up to that +restored model. + + +Other useful options: + +To reduce the training batch size in case of OOM, for example to a batch size +of approximately 48 on average: + +``` + SHARED_ARGUMENTS="${SHARED_ARGUMENTS} \ + --config.experiment_kwargs.config.training.dynamic_batch_size_config.n_node=16320 \ + --config.experiment_kwargs.config.training.dynamic_batch_size_config.n_edge=34560 \ + --config.experiment_kwargs.config.training.dynamic_batch_size_config.n_graph=48" +``` + +To reduce lead time by using dummy adjacency matrices instead of loading
+ +``` + SHARED_ARGUMENTS="${SHARED_ARGUMENTS} \ + --config.experiment_kwargs.config.dataset_kwargs.use_dummy_adjacencies=True" +``` + + +""" +# pylint: enable=line-too-long + + +import datetime +import functools +import os +import signal +import threading +from typing import Tuple + +from absl import app +from absl import flags +from absl import logging +import chex +import dill +import haiku as hk +import jax +from jax.config import config as jax_config +import jax.numpy as jnp +from jaxline import experiment +from jaxline import platform +from jaxline import utils +import jraph +from ml_collections import config_dict +import numpy as np +import optax +import tensorflow.compat.v2 as tf + +# pylint: disable=g-bad-import-order +import datasets +import losses +import models +import schedules + + +FLAGS = flags.FLAGS + + +class Experiment(experiment.AbstractExperiment): + """MAG240M-LSC Jaxline experiment.""" + + CHECKPOINT_ATTRS = { + '_params': 'params', + '_opt_state': 'opt_state', + '_network_state': 'network_state', + '_ema_network_state': 'ema_network_state', + '_ema_params': 'ema_params', + } + + def __init__( + self, + mode: str, + init_rng: jnp.ndarray, + config: config_dict.ConfigDict, + ): + """Initializes experiment.""" + super(Experiment, self).__init__(mode=mode, init_rng=init_rng) + tf.config.experimental.set_visible_devices([], device_type='gpu') + tf.config.experimental.set_visible_devices([], device_type='tpu') + + if mode not in ('train', 'eval', 'train_eval_multithreaded'): + raise ValueError(f'Invalid mode {mode}.') + + self.mode = mode + self.config = config + self.init_rng = init_rng + self.forward = hk.transform_with_state(self._forward_fn) + + self._predictions = None + + # Needed for checkpoint restore. + self._params = None + self._ema_params = None + self._network_state = None + self._ema_network_state = None + self._opt_state = None + + # Track what has started. 
+ self._training = False + self._evaluating = False + + def _train_init(self): + iterator = self._build_numpy_dataset_iterator('train', is_training=True) + self._train_input = utils.py_prefetch(lambda: iterator) + dummy_batch = next(self._train_input) + + if self._params is None: + self._initialize_experiment_state(self.init_rng, dummy_batch) + self._update_func = jax.pmap( + self._update_func, + axis_name='i', + donate_argnums=3, + ) + self._training = True + + def _eval_init(self): + + split = self.config.eval.split + # Will build the iterator at each evaluation. + self._make_eval_dataset_iterator = functools.partial( + utils.py_prefetch, + lambda: self._build_numpy_dataset_iterator(split, is_training=False)) + self.eval_forward = jax.jit( + functools.partial(self.forward.apply, is_training=False)) + self._evaluating = True + + # _ _ + # | |_ _ __ __ _(_)_ __ + # | __| '__/ _` | | '_ \ + # | |_| | | (_| | | | | | + # \__|_| \__,_|_|_| |_| + # + + def step( + self, + global_step: jnp.ndarray, + rng: jnp.ndarray, + **unused_args, + ) -> losses.LogsDict: + """See Jaxline base class.""" + if not self._training: + self._train_init() + + with jax.profiler.StepTraceAnnotation('next_train_input'): + batch = next(self._train_input) + + with jax.profiler.StepTraceAnnotation('update_step'): + (self._params, self._ema_params, self._network_state, + self._ema_network_state, self._opt_state, stats) = self._update_func( + self._params, + self._ema_params, + self._network_state, + self._ema_network_state, + self._opt_state, + global_step, + rng, + batch, + ) + del batch # Buffers donated to _update_func. 
+ + with jax.profiler.StepTraceAnnotation('get_stats'): + stats = utils.get_first(stats) + return stats + + def _build_numpy_dataset_iterator(self, split: str, is_training: bool): + if is_training: + dynamic_batch_size_config = self.config.training.dynamic_batch_size_config + else: + dynamic_batch_size_config = self.config.eval.dynamic_batch_size_config + return datasets.build_dataset_iterator( + split=split, + dynamic_batch_size_config=dynamic_batch_size_config, + debug=self.config.debug, + is_training=is_training, + **self.config.dataset_kwargs) + + def _initialize_experiment_state( + self, + init_rng: jnp.ndarray, + dummy_batch: datasets.Batch, + ): + """Initialize parameters and opt state if not restoring from checkpoint.""" + dummy_graph = dummy_batch.graph + + # Cast features to float32 so that the initialized parameters are float32. + dummy_graph = dummy_graph._replace( + nodes=jax.tree_map(lambda x: x.astype(np.float32), dummy_graph.nodes), + edges=jax.tree_map(lambda x: x.astype(np.float32), dummy_graph.edges), + ) + init_key = utils.bcast_local_devices(init_rng) + p_init = jax.pmap(functools.partial(self.forward.init, is_training=True)) + params, network_state = p_init(init_key, dummy_graph) + opt_init, _ = self._optimizer( + utils.bcast_local_devices(jnp.zeros([], jnp.int32))) + opt_state = jax.pmap(opt_init)(params) + + # For EMA decay to work correctly, params/state must be floats.
+ chex.assert_type(jax.tree_leaves(params), jnp.floating) + chex.assert_type(jax.tree_leaves(network_state), jnp.floating) + + self._params = params + self._ema_params = params + self._network_state = network_state + self._ema_network_state = network_state + self._opt_state = opt_state + + def _get_learning_rate(self, global_step: jnp.ndarray) -> jnp.ndarray: + return schedules.learning_schedule( + global_step, + **self.config.optimizer.learning_rate_schedule, + ) + + def _optimizer( + self, + learning_rate: jnp.ndarray, + ) -> optax.GradientTransformation: + optimizer_fn = getattr(optax, self.config.optimizer.name) + return optimizer_fn( + learning_rate=learning_rate, + **self.config.optimizer.kwargs, + ) + + def _forward_fn( + self, + input_graph: jraph.GraphsTuple, + is_training: bool, + stop_gradient_embedding_to_logits: bool = False, + ): + model = models.NodePropertyEncodeProcessDecode( + num_classes=datasets.NUM_CLASSES, + **self.config.model_config, + ) + return model(input_graph, is_training, stop_gradient_embedding_to_logits) + + def _bgrl_loss( + self, + params: hk.Params, + ema_params: hk.Params, + network_state: hk.State, + ema_network_state: hk.State, + rng: jnp.ndarray, + batch: datasets.Batch, + ) -> Tuple[jnp.ndarray, Tuple[losses.LogsDict, hk.State]]: + """Computes the supervised loss plus the scaled BGRL loss.""" + + # First compute two corrupted views of the graph.
+ first_corruption_key, second_corruption_key, rng = jax.random.split(rng, 3) + (first_model_key, first_model_key_ema, second_model_key, + second_model_key_ema, rng) = jax.random.split(rng, 5) + first_corrupted_graph = losses.get_corrupted_view( + batch.graph, + rng_key=first_corruption_key, + **self.config.training.loss_config.bgrl_loss_config.first_graph_corruption_config, # pylint:disable=line-too-long + ) + second_corrupted_graph = losses.get_corrupted_view( + batch.graph, + rng_key=second_corruption_key, + **self.config.training.loss_config.bgrl_loss_config.second_graph_corruption_config, # pylint:disable=line-too-long + ) + + # Then run the model on both. + first_corrupted_output, _ = self.forward.apply( + params, + network_state, + first_model_key, + first_corrupted_graph, + is_training=True, + stop_gradient_embedding_to_logits=True, + ) + second_corrupted_output, _ = self.forward.apply( + params, + network_state, + second_model_key, + second_corrupted_graph, + is_training=True, + stop_gradient_embedding_to_logits=True, + ) + first_corrupted_output_ema, _ = self.forward.apply( + ema_params, + ema_network_state, + first_model_key_ema, + first_corrupted_graph, + is_training=True, + stop_gradient_embedding_to_logits=True, + ) + second_corrupted_output_ema, _ = self.forward.apply( + ema_params, + ema_network_state, + second_model_key_ema, + second_corrupted_graph, + is_training=True, + stop_gradient_embedding_to_logits=True, + ) + + # These also contain projections for non-central nodes; remove them. 
+ num_nodes_per_graph = batch.graph.n_node + node_central_indices = jnp.concatenate( + [jnp.array([0]), jnp.cumsum(num_nodes_per_graph[:-1])]) + bgrl_loss, bgrl_stats = losses.bgrl_loss( + first_online_predictions=first_corrupted_output + .node_projection_predictions[node_central_indices], + second_target_projections=second_corrupted_output_ema + .node_embedding_projections[node_central_indices], + second_online_predictions=second_corrupted_output + .node_projection_predictions[node_central_indices], + first_target_projections=first_corrupted_output_ema + .node_embedding_projections[node_central_indices], + symmetrize=self.config.training.loss_config.bgrl_loss_config.symmetrize, + valid_mask=batch.central_node_mask[node_central_indices], + ) + + # Finally train decoder on original graph with optional stop gradient. + stop_gradient = ( + self.config.training.loss_config.bgrl_loss_config + .stop_gradient_for_supervised_loss) + model_output, new_network_state = self.forward.apply( + params, + network_state, + rng, + batch.graph, + is_training=True, + stop_gradient_embedding_to_logits=stop_gradient, + ) + supervised_loss, supervised_stats = losses.node_classification_loss( + model_output.node_logits, + batch, + ) + stats = dict(**supervised_stats, **bgrl_stats) + total_loss = ( + supervised_loss + + self.config.training.loss_config.bgrl_loss_config.bgrl_loss_scale * + bgrl_loss) + return total_loss, (stats, new_network_state) + + def _loss( + self, + params: hk.Params, + ema_params: hk.Params, + network_state: hk.State, + ema_network_state: hk.State, + rng: jnp.ndarray, + batch: datasets.Batch, + ) -> Tuple[jnp.ndarray, Tuple[losses.LogsDict, hk.State]]: + """Compute loss from params and batch.""" + + # Cast to float32 since some losses are unstable with float16. 
+ graph = batch.graph._replace( + nodes=jax.tree_map(lambda x: x.astype(jnp.float32), batch.graph.nodes), + edges=jax.tree_map(lambda x: x.astype(jnp.float32), batch.graph.edges), + ) + batch = batch._replace(graph=graph) + return self._bgrl_loss(params, ema_params, network_state, ema_network_state, + rng, batch) + + def _update_func( + self, + params: hk.Params, + ema_params: hk.Params, + network_state: hk.State, + ema_network_state: hk.State, + opt_state: optax.OptState, + global_step: jnp.ndarray, + rng: jnp.ndarray, + batch: datasets.Batch, + ) -> Tuple[hk.Params, hk.Params, hk.State, hk.State, optax.OptState, + losses.LogsDict]: + """Updates parameters.""" + + grad_fn = jax.value_and_grad(self._loss, has_aux=True) + (_, (stats, new_network_state)), grads = grad_fn( + params, + ema_params, + network_state, + ema_network_state, + rng, + batch) + learning_rate = self._get_learning_rate(global_step) + _, opt_apply = self._optimizer(learning_rate) + grad = jax.lax.pmean(grads, axis_name='i') + updates, opt_state = opt_apply(grad, opt_state, params) + params = optax.apply_updates(params, updates) + + # Stats and logging. 
+ param_norm = optax.global_norm(params) + grad_norm = optax.global_norm(grad) + ema_rate = schedules.ema_decay_schedule( + step=global_step, **self.config.eval.ema_annealing_schedule) + num_non_padded_nodes = ( + batch.graph.n_node.sum() - + jraph.get_number_of_padding_with_graphs_nodes(batch.graph)) + num_non_padded_edges = ( + batch.graph.n_edge.sum() - + jraph.get_number_of_padding_with_graphs_edges(batch.graph)) + num_non_padded_graphs = ( + batch.graph.n_node.shape[0] - + jraph.get_number_of_padding_with_graphs_graphs(batch.graph)) + avg_num_nodes = num_non_padded_nodes / num_non_padded_graphs + avg_num_edges = num_non_padded_edges / num_non_padded_graphs + stats.update( + dict( + global_step=global_step, + grad_norm=grad_norm, + param_norm=param_norm, + learning_rate=learning_rate, + ema_rate=ema_rate, + avg_num_nodes=avg_num_nodes, + avg_num_edges=avg_num_edges, + )) + ema_fn = (lambda x, y: # pylint:disable=g-long-lambda + schedules.apply_ema_decay(x, y, ema_rate)) + ema_params = jax.tree_multimap(ema_fn, ema_params, params) + ema_network_state = jax.tree_multimap( + ema_fn, + ema_network_state, + network_state, + ) + return (params, ema_params, new_network_state, ema_network_state, opt_state, + stats) + + # _ + # _____ ____ _| | + # / _ \ \ / / _` | | + # | __/\ V / (_| | | + # \___| \_/ \__,_|_| + # + + def evaluate(self, global_step, rng, **unused_kwargs): + """See base class.""" + if not self._evaluating: + self._eval_init() + + global_step = np.array(utils.get_first(global_step)) + ema_params = utils.get_first(self._ema_params) + ema_network_state = utils.get_first(self._ema_network_state) + rng = utils.get_first(rng) + + # Evaluate using the ema params. + results, predictions = self._evaluate_with_ensemble(ema_params, + ema_network_state, rng) + results['global_step'] = global_step + + # Store predictions if we got a path. 
+ self._maybe_save_predictions(predictions, global_step) + + return results + + def _evaluate_with_ensemble( + self, + params: hk.Params, + state: hk.State, + rng: jnp.ndarray, + ): + predictions_for_ensemble = [] + num_iterations = self.config.num_eval_iterations_to_ensemble + for iteration in range(num_iterations): + results, predictions = self._evaluate_params(params, state, rng) + self._log_results(f'Eval iteration {iteration}/{num_iterations}', results) + predictions_for_ensemble.append(predictions) + + if len(predictions_for_ensemble) > 1: + predictions = losses.ensemble_predictions_by_probability_average( + predictions_for_ensemble) + results = losses.get_accuracy_dict(predictions) + self._log_results(f'Ensembled {num_iterations} iterations', results) + return results, predictions + + def _maybe_save_predictions(self, predictions, global_step): + if not self.config.predictions_dir: + return + split = self.config.eval.split + output_dir = os.path.join( + self.config.predictions_dir, _get_step_date_label(global_step)) + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, split + '.dill') + + with open(output_path, 'wb') as f: + dill.dump(predictions, f) + logging.info('Saved %s predictions at: %s', split, output_path) + + def _evaluate_params( + self, + params: hk.Params, + state: hk.State, + rng: jnp.ndarray, + ): + """Evaluate given set of parameters.""" + num_valid = 0 + predictions_list = [] + labels_list = [] + logits_list = [] + indices_list = [] + for i, batch in enumerate(self._make_eval_dataset_iterator()): + model_output, _ = self.eval_forward( + params, + state, + rng, + batch.graph, + ) + + (masked_indices, + masked_predictions, + masked_labels, + masked_logits) = losses.get_predictions_labels_and_logits( + model_output.node_logits, batch) + predictions_list.append(masked_predictions) + indices_list.append(masked_indices) + labels_list.append(masked_labels) + logits_list.append(masked_logits) + + num_valid += 
jnp.sum(batch.label_mask) + + if i % 10 == 0: + logging.info('Generated predictions for %d batches so far', i + 1) + + predictions = losses.Predictions( + np.concatenate(indices_list, axis=0), + np.concatenate(labels_list, axis=0), + np.concatenate(predictions_list, axis=0), + np.concatenate(logits_list, axis=0)) + + if self.config.eval.split == 'test': + results = dict(num_valid=num_valid, accuracy=np.nan) + else: + results = losses.get_accuracy_dict(predictions) + + return results, predictions + + def _log_results(self, prefix, results): + logging_str = ', '.join( + ['{}={:.4f}'.format(k, float(results[k])) + for k in sorted(results.keys())]) + logging.info('%s: %s', prefix, logging_str) + + +def _restore_state_to_in_memory_checkpointer(restore_path): + """Initializes experiment state from a checkpoint.""" + + # Load pretrained experiment state. + python_state_path = os.path.join(restore_path, 'checkpoint.dill') + with open(python_state_path, 'rb') as f: + pretrained_state = dill.load(f) + logging.info('Restored checkpoint from %s', python_state_path) + + # Assign state to a dummy experiment instance for the in-memory checkpointer, + # broadcasting to devices. + dummy_experiment = Experiment( + mode='train', init_rng=0, config=FLAGS.config.experiment_kwargs.config) + for attribute, key in Experiment.CHECKPOINT_ATTRS.items(): + setattr(dummy_experiment, attribute, + utils.bcast_local_devices(pretrained_state[key])) + + jaxline_state = dict( + global_step=pretrained_state['global_step'], + experiment_module=dummy_experiment) + snapshot = utils.SnapshotNT(0, jaxline_state) + + # Finally, seed the jaxline `utils.InMemoryCheckpointer` global dict. + utils.GLOBAL_CHECKPOINT_DICT['latest'] = utils.CheckpointNT( + threading.local(), [snapshot]) + + +def _get_step_date_label(global_step): + # Date removing microseconds.
+ date_str = datetime.datetime.now().isoformat().split('.')[0] + return f'step_{global_step}_{date_str}' + + +def _save_state_from_in_memory_checkpointer( + save_path, experiment_class: experiment.AbstractExperiment): + """Saves experiment state to a checkpoint.""" + logging.info('Saving model.') + for checkpoint_name, checkpoint in utils.GLOBAL_CHECKPOINT_DICT.items(): + if not checkpoint.history: + logging.info('Nothing to save in "%s"', checkpoint_name) + continue + + pickle_nest = checkpoint.history[-1].pickle_nest + global_step = pickle_nest['global_step'] + + state_dict = {'global_step': global_step} + for attribute, key in experiment_class.CHECKPOINT_ATTRS.items(): + state_dict[key] = utils.get_first( + getattr(pickle_nest['experiment_module'], attribute)) + save_dir = os.path.join( + save_path, checkpoint_name, _get_step_date_label(global_step)) + python_state_path = os.path.join(save_dir, 'checkpoint.dill') + os.makedirs(save_dir, exist_ok=True) + with open(python_state_path, 'wb') as f: + dill.dump(state_dict, f) + logging.info( + 'Saved "%s" checkpoint to %s', checkpoint_name, python_state_path) + + +def _setup_signals(save_model_fn): + """Sets up a signal for model saving.""" + # Save a model on Ctrl+C. + def sigint_handler(unused_sig, unused_frame): + # Ideally, rather than saving immediately, we would then "wait" for a good + # time to save. In practice this reads from an in-memory checkpoint that + # only saves every 30 seconds or so, so chances of race conditions are very + # small. + save_model_fn() + logging.info(r'Use `Ctrl+\` to save and exit.') + + # Exit on `Ctrl+\`, saving a model. + prev_sigquit_handler = signal.getsignal(signal.SIGQUIT) + def sigquit_handler(unused_sig, unused_frame): + # Restore previous handler early, just in case something goes wrong in the + # next lines, so it is possible to press again and exit. 
+ signal.signal(signal.SIGQUIT, prev_sigquit_handler) + save_model_fn() + logging.info(r'Exiting on `Ctrl+\`') + + # Re-raise for clean exit. + os.kill(os.getpid(), signal.SIGQUIT) + + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGQUIT, sigquit_handler) + + +def main(argv, experiment_class: experiment.AbstractExperiment): + + # Maybe restore a model. + restore_path = FLAGS.config.restore_path + if restore_path: + _restore_state_to_in_memory_checkpointer(restore_path) + + # Maybe save a model. + save_dir = os.path.join(FLAGS.config.checkpoint_dir, 'models') + if FLAGS.config.one_off_evaluate: + save_model_fn = lambda: None # No need to save checkpoint in this case. + else: + save_model_fn = functools.partial( + _save_state_from_in_memory_checkpointer, save_dir, experiment_class) + _setup_signals(save_model_fn) # Save on Ctrl+C (continue) or Ctrl+\ (exit). + + try: + platform.main(experiment_class, argv) + finally: + save_model_fn() # Save at the end of training or in case of exception. + + +if __name__ == '__main__': + jax_config.update('jax_debug_nans', False) + flags.mark_flag_as_required('config') + app.run(lambda argv: main(argv, Experiment)) diff --git a/ogb_lsc/mag/generate_validation_splits.py b/ogb_lsc/mag/generate_validation_splits.py new file mode 100644 index 0000000..b16caca --- /dev/null +++ b/ogb_lsc/mag/generate_validation_splits.py @@ -0,0 +1,51 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates the k-fold validation splits.""" + +import os + +from absl import app +from absl import flags + +import data_utils + + +_DATA_ROOT = flags.DEFINE_string( + 'data_root', None, required=True, + help='Path containing the downloaded data.') + + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', None, required=True, + help='Output directory to write the splits to') + + +def main(argv): + del argv + array_dict = data_utils.get_arrays( + data_root=_DATA_ROOT.value, + return_pca_embeddings=False, + return_adjacencies=False) + + os.makedirs(_OUTPUT_DIR.value, exist_ok=True) + data_utils.generate_k_fold_splits( + train_idx=array_dict['train_indices'], + valid_idx=array_dict['valid_indices'], + output_path=_OUTPUT_DIR.value, + num_splits=data_utils.NUM_K_FOLD_SPLITS) + + +if __name__ == '__main__': + app.run(main) diff --git a/ogb_lsc/mag/losses.py b/ogb_lsc/mag/losses.py new file mode 100644 index 0000000..084db97 --- /dev/null +++ b/ogb_lsc/mag/losses.py @@ -0,0 +1,199 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Losses and related utilities.""" + +from typing import Mapping, Tuple, Sequence, NamedTuple, Dict, Optional +import jax +import jax.numpy as jnp +import jraph +import numpy as np + +# pylint: disable=g-bad-import-order +import datasets + +LogsDict = Mapping[str, jnp.ndarray] + + +class Predictions(NamedTuple): + node_indices: np.ndarray + labels: np.ndarray + predictions: np.ndarray + logits: np.ndarray + + +def node_classification_loss( + logits: jnp.ndarray, + batch: datasets.Batch, + extra_stats: bool = False, +) -> Tuple[jnp.ndarray, LogsDict]: + """Gets node-wise classification loss and statistics.""" + log_probs = jax.nn.log_softmax(logits) + loss = -jnp.sum(log_probs * batch.node_labels, axis=-1) + + num_valid = jnp.sum(batch.label_mask) + labels = jnp.argmax(batch.node_labels, axis=-1) + is_correct = (jnp.argmax(log_probs, axis=-1) == labels) + num_correct = jnp.sum(is_correct * batch.label_mask) + loss = jnp.sum(loss * batch.label_mask) / (num_valid + 1e-8) + accuracy = num_correct / (num_valid + 1e-8) + + entropy = -jnp.mean(jnp.sum(jax.nn.softmax(logits) * log_probs, axis=-1)) + + stats = { + 'classification_loss': loss, + 'prediction_entropy': entropy, + 'accuracy': accuracy, + 'num_valid': num_valid, + 'num_correct': num_correct, + } + if extra_stats: + for k in range(1, 6): + stats[f'top_{k}_correct'] = topk_correct(logits, labels, + batch.label_mask, k) + return loss, stats + + +def get_predictions_labels_and_logits( + logits: jnp.ndarray, + batch: datasets.Batch, +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Gets prediction labels and logits.""" + mask = batch.label_mask > 0. 
+ indices = batch.node_indices[mask] + logits = logits[mask] + predictions = jnp.argmax(logits, axis=-1) + labels = jnp.argmax(batch.node_labels[mask], axis=-1) + return indices, predictions, labels, logits + + +def topk_correct( + logits: jnp.ndarray, + labels: jnp.ndarray, + valid_mask: jnp.ndarray, + topk: int, +) -> jnp.ndarray: + """Calculates top-k accuracy.""" + pred_ranking = jnp.argsort(logits, axis=1)[:, ::-1] + pred_ranking = pred_ranking[:, :topk] + is_correct = jnp.any(pred_ranking == labels[:, jnp.newaxis], axis=1) + return (is_correct * valid_mask).sum() + + +def ensemble_predictions_by_probability_average( + predictions_list: Sequence[Predictions]) -> Predictions: + """Ensemble predictions by ensembling the probabilities.""" + _assert_consistent_predictions(predictions_list) + all_probs = np.stack([ + jax.nn.softmax(predictions.logits, axis=-1) + for predictions in predictions_list + ], + axis=0) + ensembled_logits = np.log(all_probs.mean(0)) + return predictions_list[0]._replace( + logits=ensembled_logits, predictions=np.argmax(ensembled_logits, axis=-1)) + + +def get_accuracy_dict(predictions: Predictions) -> Dict[str, float]: + """Returns the accuracy dict.""" + output_dict = {} + output_dict['num_valid'] = predictions.predictions.shape[0] + matches = (predictions.labels == predictions.predictions) + output_dict['accuracy'] = matches.mean() + + pred_ranking = jnp.argsort(predictions.logits, axis=1)[:, ::-1] + for k in range(1, 6): + matches = jnp.any( + pred_ranking[:, :k] == predictions.labels[:, None], axis=1) + output_dict[f'top_{k}_correct'] = matches.mean() + return output_dict + + +def bgrl_loss( + first_online_predictions: jnp.ndarray, + second_target_projections: jnp.ndarray, + second_online_predictions: jnp.ndarray, + first_target_projections: jnp.ndarray, + symmetrize: bool, + valid_mask: jnp.ndarray, +) -> Tuple[jnp.ndarray, LogsDict]: + """Implements BGRL loss.""" + first_side_node_loss = jnp.sum( + jnp.square( + 
_l2_normalize(first_online_predictions, axis=-1) - + _l2_normalize(second_target_projections, axis=-1)), + axis=-1) + if symmetrize: + second_side_node_loss = jnp.sum( + jnp.square( + _l2_normalize(second_online_predictions, axis=-1) - + _l2_normalize(first_target_projections, axis=-1)), + axis=-1) + node_loss = first_side_node_loss + second_side_node_loss + else: + node_loss = first_side_node_loss + loss = (node_loss * valid_mask).sum() / (valid_mask.sum() + 1e-6) + return loss, dict(bgrl_loss=loss) + + +def get_corrupted_view( + graph: jraph.GraphsTuple, + feature_drop_prob: float, + edge_drop_prob: float, + rng_key: jnp.ndarray, +) -> jraph.GraphsTuple: + """Returns corrupted graph view.""" + node_key, edge_key = jax.random.split(rng_key) + + def mask_feature(x): + mask = jax.random.bernoulli(node_key, 1 - feature_drop_prob, x.shape) + return x * mask + + # Randomly mask features with fixed probability. + nodes = jax.tree_map(mask_feature, graph.nodes) + + # Simulate dropping of edges by changing genuine edges to self-loops on + # the padded node. + num_edges = graph.senders.shape[0] + last_node_idx = graph.n_node.sum() - 1 + edge_mask = jax.random.bernoulli(edge_key, 1 - edge_drop_prob, [num_edges]) + senders = jnp.where(edge_mask, graph.senders, last_node_idx) + receivers = jnp.where(edge_mask, graph.receivers, last_node_idx) + # Note that n_edge will now be invalid since edges in the middle of the list + # will correspond to the final graph. Set n_edge to None to ensure we do not + # accidentally use this. 
+ return graph._replace( + nodes=nodes, + senders=senders, + receivers=receivers, + n_edge=None, + ) + + +def _assert_consistent_predictions(predictions_list: Sequence[Predictions]): + first_predictions = predictions_list[0] + for predictions in predictions_list: + assert np.all(predictions.node_indices == first_predictions.node_indices) + assert np.all(predictions.labels == first_predictions.labels) + assert np.all( + predictions.predictions == np.argmax(predictions.logits, axis=-1)) + + +def _l2_normalize( + x: jnp.ndarray, + axis: Optional[int] = None, + epsilon: float = 1e-6, +) -> jnp.ndarray: + return x * jax.lax.rsqrt( + jnp.sum(jnp.square(x), axis=axis, keepdims=True) + epsilon) diff --git a/ogb_lsc/mag/models.py b/ogb_lsc/mag/models.py new file mode 100644 index 0000000..552490f --- /dev/null +++ b/ogb_lsc/mag/models.py @@ -0,0 +1,270 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""MAG240M-LSC models.""" + +from typing import Callable, NamedTuple, Sequence + +import haiku as hk +import jax +import jax.numpy as jnp +import jraph + + +_REDUCER_NAMES = { + 'sum': + jax.ops.segment_sum, + 'mean': + jraph.segment_mean, + 'softmax': + jraph.segment_softmax, +} + + +class ModelOutput(NamedTuple): + node_embeddings: jnp.ndarray + node_embedding_projections: jnp.ndarray + node_projection_predictions: jnp.ndarray + node_logits: jnp.ndarray + + +def build_update_fn( + name: str, + output_sizes: Sequence[int], + activation: Callable[[jnp.ndarray], jnp.ndarray], + normalization_type: str, + is_training: bool, +): + """Builds update function.""" + + def single_mlp(inner_name: str): + """Creates a single MLP performing the update.""" + mlp = hk.nets.MLP( + output_sizes=output_sizes, + name=inner_name, + activation=activation) + mlp = jraph.concatenated_args(mlp) + if normalization_type == 'layer_norm': + norm = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name=name + '_layer_norm') + elif normalization_type == 'batch_norm': + batch_norm = hk.BatchNorm( + create_scale=True, + create_offset=True, + decay_rate=0.9, + name=f'{inner_name}_batch_norm', + cross_replica_axis=None if hk.running_init() else 'i', + ) + norm = lambda x: batch_norm(x, is_training) + elif normalization_type == 'none': + return mlp + else: + raise ValueError(f'Unknown normalization type {normalization_type}') + return jraph.concatenated_args(hk.Sequential([mlp, norm])) + + return single_mlp(f'{name}_homogeneous') + + +def build_gn( + output_sizes: Sequence[int], + activation: Callable[[jnp.ndarray], jnp.ndarray], + suffix: str, + use_sent_edges: bool, + is_training: bool, + dropedge_rate: float, + normalization_type: str, + aggregation_function: str, +): + """Builds an InteractionNetwork with MLP update functions.""" + node_update_fn = build_update_fn( + f'node_processor_{suffix}', + output_sizes, + activation=activation, + 
normalization_type=normalization_type, + is_training=is_training, + ) + edge_update_fn = build_update_fn( + f'edge_processor_{suffix}', + output_sizes, + activation=activation, + normalization_type=normalization_type, + is_training=is_training, + ) + + def maybe_dropedge(x): + """Dropout on edge messages.""" + if not is_training: + return x + return x * hk.dropout( + hk.next_rng_key(), + dropedge_rate, + jnp.ones([x.shape[0], 1]), + ) + + dropped_edge_update_fn = lambda *args: maybe_dropedge(edge_update_fn(*args)) + return jraph.InteractionNetwork( + update_edge_fn=dropped_edge_update_fn, + update_node_fn=node_update_fn, + aggregate_edges_for_nodes_fn=_REDUCER_NAMES[aggregation_function], + include_sent_messages_in_node_update=use_sent_edges, + ) + + +def _get_activation_fn(name: str) -> Callable[[jnp.ndarray], jnp.ndarray]: + if name == 'identity': + return lambda x: x + if hasattr(jax.nn, name): + return getattr(jax.nn, name) + raise ValueError('Unknown activation function %s specified. 
' + 'See https://jax.readthedocs.io/en/latest/jax.nn.html ' + 'for the list of supported function names.' % name) + + +class NodePropertyEncodeProcessDecode(hk.Module): + """Node Property Prediction Encode Process Decode Model.""" + + def __init__( + self, + mlp_hidden_sizes: Sequence[int], + latent_size: int, + num_classes: int, + num_message_passing_steps: int = 2, + activation: str = 'relu', + dropout_rate: float = 0.0, + dropedge_rate: float = 0.0, + use_sent_edges: bool = False, + disable_edge_updates: bool = False, + normalization_type: str = 'layer_norm', + aggregation_function: str = 'sum', + name='NodePropertyEncodeProcessDecode', + ): + super().__init__(name=name) + self._num_classes = num_classes + self._latent_size = latent_size + self._output_sizes = list(mlp_hidden_sizes) + [latent_size] + self._num_message_passing_steps = num_message_passing_steps + self._activation = _get_activation_fn(activation) + self._dropout_rate = dropout_rate + self._dropedge_rate = dropedge_rate + self._use_sent_edges = use_sent_edges + self._disable_edge_updates = disable_edge_updates + self._normalization_type = normalization_type + self._aggregation_function = aggregation_function + + def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: + node_key, edge_key = hk.next_rng_keys(2) + nodes = hk.dropout(node_key, self._dropout_rate, graph.nodes) + edges = graph.edges + if not self._disable_edge_updates: + edges = hk.dropout(edge_key, self._dropout_rate, edges) + return graph._replace(nodes=nodes, edges=edges) + + def _encode( + self, + graph: jraph.GraphsTuple, + is_training: bool, + ) -> jraph.GraphsTuple: + node_embed_fn = build_update_fn( + 'node_encoder', + self._output_sizes, + activation=self._activation, + normalization_type=self._normalization_type, + is_training=is_training,
+ ) + gn = jraph.GraphMapFeatures(edge_embed_fn, node_embed_fn) + graph = gn(graph) + if is_training: + graph = self._dropout_graph(graph) + return graph + + def _process( + self, + graph: jraph.GraphsTuple, + is_training: bool, + ) -> jraph.GraphsTuple: + for idx in range(self._num_message_passing_steps): + net = build_gn( + output_sizes=self._output_sizes, + activation=self._activation, + suffix=str(idx), + use_sent_edges=self._use_sent_edges, + is_training=is_training, + dropedge_rate=self._dropedge_rate, + normalization_type=self._normalization_type, + aggregation_function=self._aggregation_function) + residual_graph = net(graph) + graph = graph._replace(nodes=graph.nodes + residual_graph.nodes) + if not self._disable_edge_updates: + graph = graph._replace(edges=graph.edges + residual_graph.edges) + if is_training: + graph = self._dropout_graph(graph) + return graph + + def _node_mlp( + self, + graph: jraph.GraphsTuple, + is_training: bool, + output_size: int, + name: str, + ) -> jnp.ndarray: + decoder_sizes = list(self._output_sizes[:-1]) + [output_size] + net = build_update_fn( + name, + decoder_sizes, + self._activation, + normalization_type=self._normalization_type, + is_training=is_training, + ) + return net(graph.nodes) + + def __call__( + self, + graph: jraph.GraphsTuple, + is_training: bool, + stop_gradient_embedding_to_logits: bool = False, + ) -> ModelOutput: + # Note that these update configs may need to change if + # we switch back to GraphNetwork rather than InteractionNetwork. 
+ + graph = self._encode(graph, is_training) + graph = self._process(graph, is_training) + node_embeddings = graph.nodes + node_projections = self._node_mlp(graph, is_training, self._latent_size, + 'projector') + node_predictions = self._node_mlp( + graph._replace(nodes=node_projections), + is_training, + self._latent_size, + 'predictor', + ) + if stop_gradient_embedding_to_logits: + graph = jax.tree_map(jax.lax.stop_gradient, graph) + node_logits = self._node_mlp(graph, is_training, self._num_classes, + 'logits_decoder') + return ModelOutput( + node_embeddings=node_embeddings, + node_logits=node_logits, + node_embedding_projections=node_projections, + node_projection_predictions=node_predictions, + ) diff --git a/ogb_lsc/mag/neighbor_builder.py b/ogb_lsc/mag/neighbor_builder.py new file mode 100644 index 0000000..27ea6bf --- /dev/null +++ b/ogb_lsc/mag/neighbor_builder.py @@ -0,0 +1,175 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Find neighborhoods around paper feature embeddings.""" + +import pathlib + +from absl import app +from absl import flags +from absl import logging +import annoy +import numpy as np +import scipy.sparse as sp + +# pylint: disable=g-bad-import-order +import data_utils + +Path = pathlib.Path + + +_PAPER_PAPER_B_PATH = 'ogb_mag_adjacencies/paper_paper_b.npz' + +FLAGS = flags.FLAGS +flags.DEFINE_string('data_root', None, 'Data root directory') + + +def _read_paper_pca_features(): + data_root = Path(FLAGS.data_root) + path = data_root / data_utils.PCA_PAPER_FEATURES_FILENAME + with open(path, 'rb') as fid: + return np.load(fid) + + +def _read_adjacency_indices(): + # Get adjacencies. + return data_utils.get_arrays( + data_root=FLAGS.data_root, + use_fused_node_labels=False, + use_fused_node_adjacencies=False, + return_pca_embeddings=False, + ) + + +def build_annoy_index(features): + """Build the Annoy index.""" + logging.info('Building annoy index') + num_vectors, vector_size = features.shape + annoy_index = annoy.AnnoyIndex(vector_size, 'euclidean') + for i, x in enumerate(features): + annoy_index.add_item(i, x) + if i % 1000000 == 0: + logging.info('Adding: %d / %d (%.3g %%)', i, num_vectors, + 100 * i / num_vectors) + n_trees = 10 + _ = annoy_index.build(n_trees) + return annoy_index + + +def _get_annoy_index_path(): + return Path(FLAGS.data_root) / data_utils.PREPROCESSED_DIR / 'annoy_index.ann' + + +def save_annoy_index(annoy_index): + logging.info('Saving annoy index') + index_path = _get_annoy_index_path() + index_path.parent.mkdir(parents=True, exist_ok=True) + annoy_index.save(str(index_path)) + + +def read_annoy_index(features): + index_path = _get_annoy_index_path() + vector_size = features.shape[1] + annoy_index = annoy.AnnoyIndex(vector_size, 'euclidean') + annoy_index.load(str(index_path)) + return annoy_index + + +def compute_neighbor_indices_and_distances(features): + """Use the pre-built Annoy index to compute neighbor indices and distances.""" + 
logging.info('Computing neighbors and distances') + annoy_index = read_annoy_index(features) + num_vectors = features.shape[0] + + k = 20 + pad_k = 5 + search_k = -1 + neighbor_indices = np.zeros([num_vectors, k + pad_k + 1], dtype=np.int32) + neighbor_distances = np.zeros([num_vectors, k + pad_k + 1], dtype=np.float32) + for i in range(num_vectors): + neighbor_indices[i], neighbor_distances[i] = annoy_index.get_nns_by_item( + i, k + pad_k + 1, search_k=search_k, include_distances=True) + if i % 10000 == 0: + logging.info('Finding neighbors %d / %d', i, num_vectors) + return neighbor_indices, neighbor_distances + + +def _write_neighbors(neighbor_indices, neighbor_distances): + """Write neighbor indices and distances.""" + logging.info('Writing neighbors') + indices_path = Path(FLAGS.data_root) / data_utils.NEIGHBOR_INDICES_FILENAME + distances_path = ( + Path(FLAGS.data_root) / data_utils.NEIGHBOR_DISTANCES_FILENAME) + indices_path.parent.mkdir(parents=True, exist_ok=True) + distances_path.parent.mkdir(parents=True, exist_ok=True) + with open(indices_path, 'wb') as fid: + np.save(fid, neighbor_indices) + with open(distances_path, 'wb') as fid: + np.save(fid, neighbor_distances) + + +def _write_fused_edges(fused_paper_adjacency_matrix): + """Write fused edges.""" + data_root = Path(FLAGS.data_root) + edges_path = data_root / data_utils.FUSED_PAPER_EDGES_FILENAME + edges_t_path = data_root / data_utils.FUSED_PAPER_EDGES_T_FILENAME + edges_path.parent.mkdir(parents=True, exist_ok=True) + edges_t_path.parent.mkdir(parents=True, exist_ok=True) + with open(edges_path, 'wb') as fid: + sp.save_npz(fid, fused_paper_adjacency_matrix) + with open(edges_t_path, 'wb') as fid: + sp.save_npz(fid, fused_paper_adjacency_matrix.T) + + +def _write_fused_nodes(fused_node_labels): + """Write fused nodes.""" + labels_path = Path(FLAGS.data_root) / data_utils.FUSED_NODE_LABELS_FILENAME + labels_path.parent.mkdir(parents=True, exist_ok=True) + with open(labels_path, 'wb') as fid: + 
np.save(fid, fused_node_labels) + + +def main(unused_argv): + paper_pca_features = _read_paper_pca_features() + # Find neighbors. + annoy_index = build_annoy_index(paper_pca_features) + save_annoy_index(annoy_index) + neighbor_indices, neighbor_distances = compute_neighbor_indices_and_distances( + paper_pca_features) + del paper_pca_features + _write_neighbors(neighbor_indices, neighbor_distances) + + data = _read_adjacency_indices() + paper_paper_csr = data['paper_paper_index'] + paper_label = data['paper_label'] + train_indices = data['train_indices'] + valid_indices = data['valid_indices'] + test_indices = data['test_indices'] + del data + + fused_paper_adjacency_matrix = data_utils.generate_fused_paper_adjacency_matrix( + neighbor_indices, neighbor_distances, paper_paper_csr) + _write_fused_edges(fused_paper_adjacency_matrix) + del fused_paper_adjacency_matrix + del paper_paper_csr + + fused_node_labels = data_utils.generate_fused_node_labels( + neighbor_indices, neighbor_distances, paper_label, train_indices, + valid_indices, test_indices) + _write_fused_nodes(fused_node_labels) + + +if __name__ == '__main__': + flags.mark_flag_as_required('data_root') + app.run(main) diff --git a/ogb_lsc/mag/organize_data.sh b/ogb_lsc/mag/organize_data.sh new file mode 100644 index 0000000..aab7a32 --- /dev/null +++ b/ogb_lsc/mag/organize_data.sh @@ -0,0 +1,67 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
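The neighbor search in `neighbor_builder.py` asks for `k + pad_k + 1` results per item. The extra `+ 1` presumably accounts for the query item being returned as its own closest neighbor. The following is a minimal sketch of that behaviour, assuming an exact brute-force search as a stand-in for Annoy's approximate index; `brute_force_neighbors` and the toy `features` are hypothetical and not part of the repository.

```python
import math


def brute_force_neighbors(features, item, num_neighbors):
  """Exact nearest neighbours of features[item], closest first.

  Hypothetical stand-in for an approximate index lookup: the query item
  itself is part of the result, at distance 0.
  """
  scored = [(math.dist(features[item], x), j) for j, x in enumerate(features)]
  scored.sort()  # Sort by distance, ties broken by index.
  top = scored[:num_neighbors]
  return [j for _, j in top], [d for d, _ in top]


features = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [5.0, 5.0]]
k, pad_k = 1, 1  # Toy values; the script above uses k = 20 and pad_k = 5.
indices, dists = brute_force_neighbors(features, 0, k + pad_k + 1)
```

Here the first returned index is the query item itself, so only the remaining `k + pad_k` entries are true neighbors.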
+#!/bin/bash + +set -e + +while getopts ":i:o:" opt; do + case ${opt} in + i ) + INPUT_DIR=$OPTARG + ;; + o ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: organize_data.sh -i -o " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +if [[ -z "${INPUT_DIR}" ]]; then + echo "Need INPUT_DIR argument (-i )" + exit 1 +fi + +if [[ -z "${TASK_ROOT}" ]]; then + echo "Need TASK_ROOT argument (-o )" + exit 1 +fi + +DATA_ROOT="${TASK_ROOT}"/data + +# Create the raw directories (including the mv destination) to move all files to. +mkdir -p "${INPUT_DIR}"/mag240m_kddcup2021/raw "${DATA_ROOT}"/raw + +mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_feat.npy \ + "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_label.npy \ + "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper/node_year.npy \ + "${DATA_ROOT}"/raw +mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/author___affiliated_with___institution/edge_index.npy \ + "${DATA_ROOT}"/raw/author_affiliated_with_institution_edges.npy +mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/author___writes___paper/edge_index.npy \ + "${DATA_ROOT}"/raw/author_writes_paper_edges.npy +mv "${INPUT_DIR}"/mag240m_kddcup2021/processed/paper___cites___paper/edge_index.npy \ + "${DATA_ROOT}"/raw/paper_cites_paper_edges.npy + +# Split and save the train/valid/test indices to the raw directory, with names +# "train_idx.npy", "valid_idx.npy", "test_idx.npy": +python3 "${SCRIPT_DIR}"/split_and_save_indices.py --data_root="${DATA_ROOT}" diff --git a/ogb_lsc/mag/pca_builder.py b/ogb_lsc/mag/pca_builder.py new file mode 100644 index 0000000..c267ad4 --- /dev/null +++ b/ogb_lsc/mag/pca_builder.py @@ -0,0 +1,166 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Apply PCA to the papers' BERT features. + +Compute papers' PCA features. +Recompute author and institution features from the paper PCA features. +""" + +import pathlib +import time + +from absl import app +from absl import flags +from absl import logging +import numpy as np + +# pylint: disable=g-bad-import-order +import data_utils + +Path = pathlib.Path + +_NUMBER_OF_PAPERS_TO_ESTIMATE_PCA_ON = 1000000 # None indicates all. + +FLAGS = flags.FLAGS +flags.DEFINE_string('data_root', None, 'Data root directory') + + +def _sample_vectors(vectors, num_samples, seed=0): + """Randomly sample some vectors.""" + rand = np.random.RandomState(seed=seed) + indices = rand.choice(vectors.shape[0], size=num_samples, replace=False) + return vectors[indices] + + +def _pca(feat): + """Returns evals (variances), evecs (rows are principal components).""" + cov = np.cov(feat.T) + _, evals, evecs = np.linalg.svd(cov, full_matrices=True) + return evals, evecs + + +def _read_raw_paper_features(): + """Load raw paper features.""" + path = Path(FLAGS.data_root) / data_utils.RAW_NODE_FEATURES_FILENAME + try: # Use mmap if possible. + features = np.load(path, mmap_mode='r') + except FileNotFoundError: + with open(path, 'rb') as fid: + features = np.load(fid) + return features + + +def _get_principal_components(features, + num_principal_components=129, + num_samples=10000, + seed=2, + dtype='f4'): + """Estimate PCA features.""" + sample = _sample_vectors( + features[:_NUMBER_OF_PAPERS_TO_ESTIMATE_PCA_ON], num_samples, seed=seed) + # Compute PCA basis. 
+ _, evecs = _pca(sample) + return evecs[:num_principal_components].T.astype(dtype) + + +def _project_features_onto_principal_components(features, + principal_components, + block_size=1000000): + """Apply PCA iteratively.""" + num_principal_components = principal_components.shape[1] + dtype = principal_components.dtype + num_vectors = features.shape[0] + num_features = features.shape[0] + num_blocks = (num_features - 1) // block_size + 1 + pca_features = np.empty([num_vectors, num_principal_components], dtype=dtype) + # Loop through in blocks. + start_time = time.time() + for i in range(num_blocks): + i_start = i * block_size + i_end = (i + 1) * block_size + f = np.array(features[i_start:i_end].copy()) + pca_features[i_start:i_end] = np.dot(f, principal_components).astype(dtype) + del f + elapsed_time = time.time() - start_time + time_left = elapsed_time / (i + 1) * (num_blocks - i - 1) + logging.info('Features %d / %d. Elapsed time %.1f. Time left: %.1f', i_end, + num_vectors, elapsed_time, time_left) + return pca_features + + +def _read_adjacency_indices(): + # Get adjacencies. 
+ return data_utils.get_arrays( + data_root=FLAGS.data_root, + use_fused_node_labels=False, + use_fused_node_adjacencies=False, + return_pca_embeddings=False, + ) + + +def _compute_author_pca_features(paper_pca_features, index_arrays): + return data_utils.paper_features_to_author_features( + index_arrays['author_paper_index'], paper_pca_features) + + +def _compute_institution_pca_features(author_pca_features, index_arrays): + return data_utils.author_features_to_institution_features( + index_arrays['institution_author_index'], author_pca_features) + + +def _write_array(path, array): + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, 'wb') as fid: + np.save(fid, array) + + +def main(unused_argv): + data_root = Path(FLAGS.data_root) + + raw_paper_features = _read_raw_paper_features() + principal_components = _get_principal_components(raw_paper_features) + paper_pca_features = _project_features_onto_principal_components( + raw_paper_features, principal_components) + del raw_paper_features + del principal_components + + paper_pca_path = data_root / data_utils.PCA_PAPER_FEATURES_FILENAME + author_pca_path = data_root / data_utils.PCA_AUTHOR_FEATURES_FILENAME + institution_pca_path = ( + data_root / data_utils.PCA_INSTITUTION_FEATURES_FILENAME) + merged_pca_path = data_root / data_utils.PCA_MERGED_FEATURES_FILENAME + _write_array(paper_pca_path, paper_pca_features) + + # Compute author and institution features from paper PCA features. 
+ index_arrays = _read_adjacency_indices() + author_pca_features = _compute_author_pca_features(paper_pca_features, + index_arrays) + _write_array(author_pca_path, author_pca_features) + + institution_pca_features = _compute_institution_pca_features( + author_pca_features, index_arrays) + _write_array(institution_pca_path, institution_pca_features) + + merged_pca_features = np.concatenate( + [paper_pca_features, author_pca_features, institution_pca_features], + axis=0) + del author_pca_features + del institution_pca_features + _write_array(merged_pca_path, merged_pca_features) + + +if __name__ == '__main__': + flags.mark_flag_as_required('data_root') + app.run(main) diff --git a/ogb_lsc/mag/requirements.txt b/ogb_lsc/mag/requirements.txt new file mode 100644 index 0000000..b1f5666 --- /dev/null +++ b/ogb_lsc/mag/requirements.txt @@ -0,0 +1,23 @@ +wheel +absl-py>=0.12.0 +chex>=0.0.7 +dill>=0.3.3 +dm-haiku>=0.0.4 +# jaxline +git+https://github.com/deepmind/jaxline@927695a0b88e480d085590736e2daa36c2f2c68b +optax>=0.0.8 +ml_collections +numpy>=1.16.4 +tensorflow>=2.5.0 +tensorflow-datasets>=4.3.0 +dm-tree>=0.1.6 +ogb>=1.3.1 +# graph_nets +git+https://github.com/deepmind/graph_nets.git@64771dff0d74ca8e77b1f1dcd5a7d26634356d61 +scipy>=1.2.1 +jaxlib>=0.1.57+cuda110 +tqdm>=4.28.1 +annoy>=1.0.3 +# jraph +git+https://github.com/deepmind/jraph.git@80d2f9e4d82f841a71d56ba44fa9e8781e93ae1f +google-cloud-storage diff --git a/ogb_lsc/mag/run_preprocessing.sh b/ogb_lsc/mag/run_preprocessing.sh new file mode 100755 index 0000000..bed3c98 --- /dev/null +++ b/ogb_lsc/mag/run_preprocessing.sh @@ -0,0 +1,57 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
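The block-wise projection in `_project_features_onto_principal_components` (`pca_builder.py`) uses ceiling division to count the blocks and then multiplies each block by the principal components. Below is a minimal pure-Python sketch of that loop, assuming toy list-of-lists data in place of NumPy arrays; `block_ranges`, `project`, and the `components` basis are hypothetical names chosen for illustration.

```python
def block_ranges(num_vectors, block_size):
  """Yield (start, end) row ranges covering num_vectors rows in blocks."""
  num_blocks = (num_vectors - 1) // block_size + 1  # Ceiling division.
  for i in range(num_blocks):
    start = i * block_size
    # The original relies on slicing to clamp the end; here it is explicit.
    end = min((i + 1) * block_size, num_vectors)
    yield start, end


def project(rows, components):
  """Project each row onto the columns of `components` (a plain matmul)."""
  num_components = len(components[0])
  return [[
      sum(r * c[j] for r, c in zip(row, components))
      for j in range(num_components)
  ] for row in rows]


features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
components = [[1.0], [0.0]]  # Hypothetical basis: keep the first coordinate.
projected = []
for start, end in block_ranges(len(features), block_size=2):
  projected.extend(project(features[start:end], components))
```

Processing in blocks keeps only one slice of the (possibly memory-mapped) feature matrix in memory at a time, which appears to be the motivation for the loop in the original function.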
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#!/bin/bash + +set -e +set -x + + +while getopts ":r:" opt; do + case ${opt} in + r ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: run_preprocessing.sh -r " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + + +DATA_ROOT="${TASK_ROOT}"/data +PREPROCESSED_DIR="${DATA_ROOT}"/preprocessed + +# Create preprocessed directory to move all files to it. +mkdir -p "${PREPROCESSED_DIR}" + +# Run the CSR edge builder. +python "${SCRIPT_DIR}"/csr_builder.py --data_root="${DATA_ROOT}" + +# Run the PCA feature builder. +python "${SCRIPT_DIR}"/pca_builder.py --data_root="${DATA_ROOT}" + +# Run the neighbor-finder/fuser builder. +python "${SCRIPT_DIR}"/neighbor_builder.py --data_root="${DATA_ROOT}" + +# Run the validation split generator. +python "${SCRIPT_DIR}"/generate_validation_splits.py \ + --data_root="${DATA_ROOT}" \ + --output_dir="${DATA_ROOT}/k_fold_splits" diff --git a/ogb_lsc/mag/run_pretrained_eval.sh b/ogb_lsc/mag/run_pretrained_eval.sh new file mode 100755 index 0000000..e83417f --- /dev/null +++ b/ogb_lsc/mag/run_pretrained_eval.sh @@ -0,0 +1,107 @@ +#!/bin/bash +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +set -x + +while getopts ":r:" opt; do + case ${opt} in + r ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: run_pretrained_eval.sh -r " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +echo " +Note this script may take several days to run with default parameters on +a single machine. Reducing EPOCHS_TO_ENSEMBLE from 50 to < 5 should yield +slightly lower validation performance but possibly similar test performance. +" + +echo " +Pre-requisites (See README): +* Python dependencies have been installed. +* pre-processed data is available in the task dir. +* pre-trained model weights are available in the task dir. +" + +read -p "Press enter to continue" + + +# Can set this to "valid" or "test". +# On test the results will be ensembled for 10 models, and a submission file +# will be created. +# On validation, the results will be "gathered" for 10 models. This is because +# each model is trained on the train split + 90% of the validation split, +# and only evaluated on the remaining 10%, such that each validation paper +# is left out from training in exactly one of the 10 models. +SPLIT="test" + +# We used 50 epochs in the submission to sample different subgraphs around +# each central node. 
+# For each of the 10 models in the ensemble, a single epoch takes about +# 10-25 minutes (depending on the GPU) on the test set, and about 2-3 minutes +# for the corresponding k-fold split of the validation set. +EPOCHS_TO_ENSEMBLE=50 + +DATA_ROOT=${TASK_ROOT}/data/ +MODELS_ROOT=${TASK_ROOT}/models/ +CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/ +OUTPUT_DIR=${TASK_ROOT}/predictions/ + +# We run two seeds for each model of the k=10 k-fold +# first seed group: [100, 101, 102, 103, 104, 105, 106, 107, 108, 109] +# second seed group: [110, 111, 112, 113, 114, 115, 116, 117, 118, 119] +# These are the seeds that were selected based on cross validation for each fold. +BEST_SEEDS=(100 111 102 113 104 105 106 107 108 109) +for K_FOLD_INDEX in {0..9}; do + SEED=${BEST_SEEDS[${K_FOLD_INDEX}]} + RESTORE_PATH=${MODELS_ROOT}/k${K_FOLD_INDEX}_seed${SEED} + echo "Running k=${K_FOLD_INDEX} on ${SPLIT} split using ${RESTORE_PATH}" + # This saves the predictions for the K_FOLD_INDEX'th model in the k-fold to + # "config.experiment_kwargs.config.predictions_dir" for subsequent ensembling + python "${SCRIPT_DIR}"/experiment.py \ + --jaxline_mode="eval" \ + --config="${SCRIPT_DIR}"/config.py \ + --config.one_off_evaluate=True \ + --config.checkpoint_dir=${CHECKPOINT_DIR}/${K_FOLD_INDEX} \ + --config.restore_path=${RESTORE_PATH} \ + --config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \ + --config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \ + --config.experiment_kwargs.config.predictions_dir=${OUTPUT_DIR}/${K_FOLD_INDEX} \ + --config.experiment_kwargs.config.eval.split=${SPLIT} + +done + + +python "${SCRIPT_DIR}"/ensemble_predictions.py \ + --split=${SPLIT} \ + --data_root=${DATA_ROOT} \ + --predictions_path=${OUTPUT_DIR} \ + --output_path=${OUTPUT_DIR} + + +echo "Done" diff --git a/ogb_lsc/mag/run_training.sh b/ogb_lsc/mag/run_training.sh new file 
mode 100755 index 0000000..887d979 --- /dev/null +++ b/ogb_lsc/mag/run_training.sh @@ -0,0 +1,115 @@ +#!/bin/bash +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +set -x + +while getopts ":r:" opt; do + case ${opt} in + r ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: run_training.sh -r " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +echo " +This script is provided for illustrative purposes. It is not practical for +actual training since it only uses a single machine, and likely requires +reducing the batch size and/or model size to fit on a single GPU. + +For the actual submission we used training distributed in different ways: +* We used 4x Cloud TPU v4s to implement batch parallelism, so batch + size is effectively 8 times larger than the values in the config. +* We used separate CPUs to subsample neighborhoods around each central node, to + maximize TPU usage by always having a batch ready for the next iteration. +* We ran the online early-stopping evaluator on a separate machine with an + NVIDIA V100 GPU. +* We ran identical replicas of all of the above for each of the 20 models + trained. + +Using those mechanisms, training runs at ~2 steps per second, reaching 500k +steps in under 3 days. 
+" + +echo " +Pre-requisites (See README): +* Python dependencies have been installed. +* pre-processed data is available in the task dir. +" + +read -p "Press enter to continue" + +# During early stopping seed/selection we ensembled 5 iterations. +EPOCHS_TO_ENSEMBLE=5 + +DATA_ROOT=${TASK_ROOT}/data/ +CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/ + +# We run two seeds for each model of the k=10 k-fold. +BASE_SEED=100 +for SEED_OFFSET in 0 10; do +for K_FOLD_INDEX in {0..9}; do + MODEL_SEED=`expr ${BASE_SEED} + ${SEED_OFFSET} + ${K_FOLD_INDEX}` + SUFFIX=k${K_FOLD_INDEX}_seed${MODEL_SEED} + echo "Running k=${K_FOLD_INDEX} with init seed ${MODEL_SEED}" + + # This runs training (each model is trained on train split + 90% of the + # validation split) with early stopping, storing both "latest" model and + # "best" early-stopped model at `--config.checkpoint_dir`. + + # Models are early stopped based on accuracy on the 10% of the validation + # data that is left out from training (each K_FOLD_INDEX leaves a different + # 10% of the data out). + + # Models are stored at the end of training. Intermediate models can also be + # stored while training by sending a SIGINT signal (Ctrl+C) which will not + # interrupt the training. + + # It is possible to interrupt training using (Ctrl+\) and then continue + # passing the corresponding `--config.restore_path=${RESTORE_PATH}` which is + # stored when interrupting. 
+ + python "${SCRIPT_DIR}"/experiment.py \ + --jaxline_mode="train_eval_multithreaded" \ + --config="${SCRIPT_DIR}"/config.py \ + --config.random_seed=${MODEL_SEED} \ + --config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \ + --config.experiment_kwargs.config.dataset_kwargs.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_kwargs.k_fold_split_id=${K_FOLD_INDEX} \ + --config.experiment_kwargs.config.num_eval_iterations_to_ensemble=${EPOCHS_TO_ENSEMBLE} \ + --config.experiment_kwargs.config.eval.split="valid" + +done +done + +# Each of the 20 (two for each value of k) jobs generates paths of the form: +# RESTORE_PATH=${--config.checkpoint_dir}/models/best/step_${STEP}_${TIMESTAMP} +# From each pair, the best one is selected (based on the validation accuracy +# as reported in the logs or in tensorboard events also stored at +# `${--config.checkpoint_dir}/eval`). These can then be used as the +# "RESTORE_PATHS" for `./run_pretrained_eval.sh`. + + +echo "Done" diff --git a/ogb_lsc/mag/schedules.py b/ogb_lsc/mag/schedules.py new file mode 100644 index 0000000..ca605b8 --- /dev/null +++ b/ogb_lsc/mag/schedules.py @@ -0,0 +1,74 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
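The nested loops in `run_training.sh` launch 20 runs, two seeds for each of the 10 folds, with the seed derived as `BASE_SEED + SEED_OFFSET + K_FOLD_INDEX`. As a rough illustration, the same enumeration can be written in Python; `training_runs` is a hypothetical helper, not part of the repository.

```python
def training_runs(base_seed=100, seed_offsets=(0, 10), num_folds=10):
  """Enumerate (fold index, model seed, checkpoint suffix) for every run."""
  runs = []
  for seed_offset in seed_offsets:
    for k_fold_index in range(num_folds):
      model_seed = base_seed + seed_offset + k_fold_index
      suffix = f'k{k_fold_index}_seed{model_seed}'
      runs.append((k_fold_index, model_seed, suffix))
  return runs


runs = training_runs()
```

Each fold therefore appears twice, with seeds that differ by 10, which matches the two seed groups listed in `run_pretrained_eval.sh`.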
+ +"""Scheduling utilities.""" + +import jax.numpy as jnp + + +def apply_ema_decay( + ema_value: jnp.ndarray, + current_value: jnp.ndarray, + decay: jnp.ndarray, +) -> jnp.ndarray: + """Implements EMA.""" + return ema_value * decay + current_value * (1 - decay) + + +def ema_decay_schedule( + base_rate: jnp.ndarray, + step: jnp.ndarray, + total_steps: jnp.ndarray, + use_schedule: bool, +) -> jnp.ndarray: + """Anneals decay rate to 1 with cosine schedule.""" + if not use_schedule: + return base_rate + multiplier = _cosine_decay(step, total_steps, 1.) + return 1. - (1. - base_rate) * multiplier + + +def _cosine_decay( + global_step: jnp.ndarray, + max_steps: int, + initial_value: float, +) -> jnp.ndarray: + """Simple implementation of cosine decay from TF1.""" + global_step = jnp.minimum(global_step, max_steps).astype(jnp.float32) + cosine_decay_value = 0.5 * (1 + jnp.cos(jnp.pi * global_step / max_steps)) + decayed_learning_rate = initial_value * cosine_decay_value + return decayed_learning_rate + + +def learning_schedule( + global_step: jnp.ndarray, + base_learning_rate: float, + total_steps: int, + warmup_steps: int, + use_schedule: bool, +) -> float: + """Cosine learning rate scheduler.""" + # Compute LR & Scaled LR + if not use_schedule: + return base_learning_rate + warmup_learning_rate = ( + global_step.astype(jnp.float32) / int(warmup_steps) * + base_learning_rate if warmup_steps > 0 else base_learning_rate) + + # Cosine schedule after warmup. + decay_learning_rate = _cosine_decay(global_step - warmup_steps, + total_steps - warmup_steps, + base_learning_rate) + return jnp.where(global_step < warmup_steps, warmup_learning_rate, + decay_learning_rate) diff --git a/ogb_lsc/mag/split_and_save_indices.py b/ogb_lsc/mag/split_and_save_indices.py new file mode 100644 index 0000000..61e59d8 --- /dev/null +++ b/ogb_lsc/mag/split_and_save_indices.py @@ -0,0 +1,50 @@ +# Copyright 2021 DeepMind Technologies Limited. 
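To make the behaviour of the schedules in `schedules.py` concrete, here is a minimal scalar sketch using only the standard library instead of `jax.numpy`. It mirrors the linear warmup followed by cosine decay, and the EMA decay rate annealed towards 1. The function names and the numbers in the usage lines are illustrative assumptions, not values from the repository's config.

```python
import math


def cosine_decay(step, max_steps, initial_value):
  """Cosine decay from initial_value at step 0 to 0 at max_steps."""
  step = min(step, max_steps)
  return initial_value * 0.5 * (1 + math.cos(math.pi * step / max_steps))


def learning_rate(step, base_lr, total_steps, warmup_steps):
  """Linear warmup to base_lr, then cosine decay over the remaining steps."""
  if warmup_steps > 0 and step < warmup_steps:
    return step / warmup_steps * base_lr
  return cosine_decay(step - warmup_steps, total_steps - warmup_steps, base_lr)


def ema_decay(base_rate, step, total_steps):
  """Anneal the EMA decay rate from base_rate towards 1."""
  return 1.0 - (1.0 - base_rate) * cosine_decay(step, total_steps, 1.0)


# Illustrative values only.
lr_mid_warmup = learning_rate(5, base_lr=1.0, total_steps=100, warmup_steps=10)
lr_peak = learning_rate(10, base_lr=1.0, total_steps=100, warmup_steps=10)
lr_end = learning_rate(100, base_lr=1.0, total_steps=100, warmup_steps=10)
ema_start = ema_decay(0.99, 0, 100)
ema_end = ema_decay(0.99, 100, 100)
```

The learning rate peaks exactly when warmup ends and reaches zero at `total_steps`, while the EMA decay rate starts at `base_rate` and ends at 1.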
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Split and save the train/valid/test indices. + +Usage: + +python3 split_and_save_indices.py --data_root="mag_data" +""" + +import pathlib + +from absl import app +from absl import flags +import numpy as np +import torch + +Path = pathlib.Path + + +FLAGS = flags.FLAGS + +flags.DEFINE_string('data_root', None, 'Data root directory') + + +def main(argv) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + mag_directory = Path(FLAGS.data_root) / 'mag240m_kddcup2021' + raw_directory = mag_directory / 'raw' + raw_directory.mkdir(parents=True, exist_ok=True) + splits_dict = torch.load(str(mag_directory / 'split_dict.pt')) + for key, indices in splits_dict.items(): + np.save(str(raw_directory / f'{key}_idx.npy'), indices) + + +if __name__ == '__main__': + flags.mark_flag_as_required('data_root') + app.run(main) diff --git a/ogb_lsc/mag/sub_sampler.py b/ogb_lsc/mag/sub_sampler.py new file mode 100644 index 0000000..905fc70 --- /dev/null +++ b/ogb_lsc/mag/sub_sampler.py @@ -0,0 +1,250 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for subsampling the MAG dataset.""" + +import collections + +import jraph +import numpy as np + + +def get_or_sample_row(node_id: int, + nb_neighbours: int, + csr_matrix, remove_duplicates: bool): + """Either obtain entire row or a subsampled set of neighbours.""" + if node_id + 1 >= csr_matrix.indptr.shape[0]: + lo = 0 + hi = 0 + else: + lo = csr_matrix.indptr[node_id] + hi = csr_matrix.indptr[node_id + 1] + if lo == hi: # Skip empty neighbourhoods + neighbours = None + elif hi - lo <= nb_neighbours: + neighbours = csr_matrix.indices[lo:hi] + elif hi - lo < 5 * nb_neighbours: # For small surroundings, sample directly + nb_neighbours = min(nb_neighbours, hi - lo) + inds = lo + np.random.choice(hi - lo, size=(nb_neighbours,), replace=False) + neighbours = csr_matrix.indices[inds] + else: # Otherwise, do not slice -- sample indices instead + # To extend GraphSAGE ("uniform w/ replacement"), modify this call + inds = np.random.randint(lo, hi, size=(nb_neighbours,)) + if remove_duplicates: + inds = np.unique(inds) + neighbours = csr_matrix.indices[inds] + return neighbours + + +def get_neighbours(node_id: int, + node_type: int, + neighbour_type: int, + nb_neighbours: int, + remove_duplicates: bool, + author_institution_csr, institution_author_csr, + author_paper_csr, paper_author_csr, + paper_paper_csr, paper_paper_transpose_csr): + """Fetch the edge indices from one node to corresponding neighbour type.""" + if node_type == 0 and neighbour_type == 0: + csr = paper_paper_transpose_csr # Citing + elif node_type == 0 and neighbour_type == 1: + 
csr = paper_author_csr + elif node_type == 0 and neighbour_type == 3: + csr = paper_paper_csr # Cited + elif node_type == 1 and neighbour_type == 0: + csr = author_paper_csr + elif node_type == 1 and neighbour_type == 2: + csr = author_institution_csr + elif node_type == 2 and neighbour_type == 1: + csr = institution_author_csr + else: + raise ValueError('Non-existent edge type requested') + return get_or_sample_row(node_id, nb_neighbours, csr, remove_duplicates) + + +def get_senders(neighbour_type: int, + sender_index, + paper_features): + """Get the sender features from given neighbours.""" + if neighbour_type == 0 or neighbour_type == 3: + sender_features = paper_features[sender_index] + elif neighbour_type == 1 or neighbour_type == 2: + sender_features = np.zeros((sender_index.shape[0], + paper_features.shape[1])) # Consider averages + else: + raise ValueError('Non-existent node type requested') + return sender_features + + +def make_edge_type_feature(node_type: int, neighbour_type: int): + edge_feats = np.zeros(7) + edge_feats[node_type] = 1.0 + edge_feats[neighbour_type + 3] = 1.0 + return edge_feats + + +def subsample_graph(paper_id: int, + author_institution_csr, + institution_author_csr, + author_paper_csr, + paper_author_csr, + paper_paper_csr, + paper_paper_transpose_csr, + max_nb_neighbours_per_type, + max_nodes=None, + max_edges=None, + paper_years=None, + remove_future_nodes=False, + deduplicate_nodes=False) -> jraph.GraphsTuple: + """Subsample a graph around given paper ID.""" + if paper_years is not None: + root_paper_year = paper_years[paper_id] + else: + root_paper_year = None + # Add the center node as "node-zero" + sub_nodes = [paper_id] + num_nodes_in_subgraph = 1 + num_edges_in_subgraph = 0 + reached_node_budget = False + reached_edge_budget = False + node_and_type_to_index_in_subgraph = dict() + node_and_type_to_index_in_subgraph[(paper_id, 0)] = 0 + # Store all (integer) depths as an additional feature + depths = [0] + types = [0] + 
sub_edges = [] + sub_senders = [] + sub_receivers = [] + + # Store all unprocessed neighbours + # Each neighbour is stored as a 4-tuple (node_index in original graph, + # node_index in subsampled graph, type, number of hops away from source). + # TYPES: 0: paper, 1: author, 2: institution, 3: paper (for bidirectional) + neighbour_deque = collections.deque([(paper_id, 0, 0, 0)]) + + max_depth = len(max_nb_neighbours_per_type) + + while neighbour_deque and not reached_edge_budget: + left_entry = neighbour_deque.popleft() + node_index, node_index_in_sampled_graph, node_type, node_depth = left_entry + + # Expand from this node, to a node of related type + for neighbour_type in range(4): + if reached_edge_budget: + break # Budget may have been reached in previous type; break here. + nb_neighbours = max_nb_neighbours_per_type[node_depth][node_type][neighbour_type] # pylint:disable=line-too-long + # Only extend if we want to sample further in this edge type + if nb_neighbours > 0: + sampled_neighbors = get_neighbours( + node_index, + node_type, + neighbour_type, + nb_neighbours, + deduplicate_nodes, + author_institution_csr, + institution_author_csr, + author_paper_csr, + paper_author_csr, + paper_paper_csr, + paper_paper_transpose_csr, + ) + + if sampled_neighbors is not None: + if remove_future_nodes and root_paper_year is not None: + if neighbour_type in [0, 3]: + sampled_neighbors = [ + x for x in sampled_neighbors + if paper_years[x] <= root_paper_year + ] + if not sampled_neighbors: + continue + + nb_neighbours = len(sampled_neighbors) + edge_feature = make_edge_type_feature(node_type, neighbour_type) + + for neighbor_original_idx in sampled_neighbors: + # Key into dict of existing nodes using both node id and type. + neighbor_key = (neighbor_original_idx, neighbour_type % 3) + # Get existing idx in subgraph if it exists. 
+ neighbor_subgraph_idx = node_and_type_to_index_in_subgraph.get( + neighbor_key, None) + if (not reached_node_budget and + (not deduplicate_nodes or neighbor_subgraph_idx is None)): + # If it does not exist already, or we are not deduplicating, + # just create a new node and update the dict. + neighbor_subgraph_idx = num_nodes_in_subgraph + node_and_type_to_index_in_subgraph[neighbor_key] = ( + neighbor_subgraph_idx) + num_nodes_in_subgraph += 1 + sub_nodes.append(neighbor_original_idx) + types.append(neighbour_type % 3) + depths.append(node_depth + 1) + if max_nodes is not None and num_nodes_in_subgraph >= max_nodes: + reached_node_budget = True + continue # Move to next neighbor which might already exist. + if node_depth < max_depth - 1: + # If the neighbours are to be further expanded, enqueue them. + # Expand only if the nodes did not already exist. + neighbour_deque.append( + (neighbor_original_idx, neighbor_subgraph_idx, + neighbour_type % 3, node_depth + 1)) + # The neighbor id within graph is now fixed; just add edges. + if neighbor_subgraph_idx is not None: + # Either node existed before or was successfully added. + sub_senders.append(neighbor_subgraph_idx) + sub_receivers.append(node_index_in_sampled_graph) + sub_edges.append(edge_feature) + num_edges_in_subgraph += 1 + if max_edges is not None and num_edges_in_subgraph >= max_edges: + reached_edge_budget = True + break # Break out of adding edges for this neighbor type + + # Stitch the graph together + sub_nodes = np.array(sub_nodes, dtype=np.int32) + + if sub_senders: + sub_senders = np.array(sub_senders, dtype=np.int32) + sub_receivers = np.array(sub_receivers, dtype=np.int32) + sub_edges = np.stack(sub_edges, axis=0) + else: + # Use empty arrays. 
+ sub_senders = np.zeros([0], dtype=np.int32) + sub_receivers = np.zeros([0], dtype=np.int32) + sub_edges = np.zeros([0, 7]) + + # Finally, derive the sizes + sub_n_node = np.array([sub_nodes.shape[0]]) + sub_n_edge = np.array([sub_senders.shape[0]]) + assert sub_nodes.shape[0] == num_nodes_in_subgraph + assert sub_edges.shape[0] == num_edges_in_subgraph + if max_nodes is not None: + assert num_nodes_in_subgraph <= max_nodes + if max_edges is not None: + assert num_edges_in_subgraph <= max_edges + + types = np.array(types) + depths = np.array(depths) + sub_nodes = { + 'index': sub_nodes.astype(np.int32), + 'type': types.astype(np.int16), + 'depth': depths.astype(np.int16), + } + + return jraph.GraphsTuple(nodes=sub_nodes, + edges=sub_edges.astype(np.float16), + senders=sub_senders.astype(np.int32), + receivers=sub_receivers.astype(np.int32), + globals=np.array([0], dtype=np.int16), + n_node=sub_n_node.astype(dtype=np.int32), + n_edge=sub_n_edge.astype(dtype=np.int32)) diff --git a/ogb_lsc/pcq/README.md b/ogb_lsc/pcq/README.md new file mode 100644 index 0000000..0ce56bb --- /dev/null +++ b/ogb_lsc/pcq/README.md @@ -0,0 +1,110 @@ +# DeepMind entry for PCQM4M-LSC + +This repository contains DeepMind's entry to the [PCQM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry) +track of the [OGB Large-Scale Challenge](https://ogb.stanford.edu/kddcup2021/) +(OGB-LSC). + +For full details regarding this entry, please see our [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf). + +## DeepMind PCQ Team ("Quantum") + +(in alphabetical order) + +- Ravichandra Addanki +- Peter Battaglia +- David Budden +- Andreea Deac +- Jonathan Godwin +- Alvaro Sanchez-Gonzalez +- Wai Lok Sibon Li +- Jacklynn Stott +- Shantanu Thakoor +- Petar Veličković + + +## Performance + +Our final test set performance was achieved by pooling an ensemble of 20 models +(10 folds x 2 seeds). 
See the [technical report](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf) for details. + +Each model was trained for < 48 hours using 4x Google Cloud TPUv4 and 1x AMD +EPYC 7B12 64-core CPU @2.25GHz. + +Inference takes < 6 hours on 1x NVIDIA V100 16GB GPU and 1x Intel Xeon Gold 6148 +20-core CPU @2.40GHz. + +# Running our model + +## Setup + +You can set up a Python virtual environment (you might need to install the +`python3-venv` package first) with all needed dependencies inside the forked +`deepmind_research` repository using: + +```bash +python3 -m venv /tmp/pcq_venv +source /tmp/pcq_venv/bin/activate +pip3 install --upgrade pip setuptools wheel +pip3 install -r ogb_lsc/pcq/requirements.txt +``` + +## Downloading data and model weights + +All necessary data and pre-trained model weights can be downloaded by running +the following command. +This downloads about 150 GB of model checkpoints. + +```bash +python download_required_pcq_data.py --data_root=${HOME}/data/ +``` + +## Generating pre-processed features + +All the additional features used in training +(k-fold splits and conformer position features) can be generated by running: + +```bash +/bin/bash run_preprocessing.sh -r ${HOME}/data/pcq/ +``` + +## Reproducing our final results + +We have provided pre-trained weights of our final submission for convenience. To +reproduce our final results, please run `run_pretrained_eval.sh` as follows. + +```bash +/bin/bash run_pretrained_eval.sh -r ${HOME}/data/pcq/ +``` + +## Retraining our model + +Disclaimer: This script is provided for illustrative purposes. It is not +practical for actual training since it only uses a single machine, and likely +requires reducing the batch size and/or model size to fit on a single GPU. + +If you still want to train a model, please run `run_training.sh`. 
To simply +validate that the code is running correctly on your hardware setup, consider +setting `debug=True` in `config.py`, which trains a smaller model. + + +```bash +/bin/bash run_training.sh -r ${HOME}/data/pcq/ +``` + + +# Citation + +To cite this work (together with our MAG240M-LSC entry): + +```latex +@article{deepmind2021ogb, + author = {Ravichandra Addanki and Peter Battaglia and David Budden and Andreea + Deac and Jonathan Godwin and Thomas Keck and Wai Lok Sibon Li and Alvaro + Sanchez-Gonzalez and Jacklynn Stott and Shantanu Thakoor and Petar + Veli\v{c}kovi\'{c}}, + title = {Large-scale graph representation learning with very deep GNNs and + self-supervision}, + year = {2021}, +} +``` + +Our technical report can be found [here](https://storage.googleapis.com/deepmind-ogb-lsc/reports/OGB_LSC_Tech_Report.pdf). diff --git a/ogb_lsc/pcq/batching_utils.py b/ogb_lsc/pcq/batching_utils.py new file mode 100644 index 0000000..b06dcbd --- /dev/null +++ b/ogb_lsc/pcq/batching_utils.py @@ -0,0 +1,135 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Dynamic batching utilities.""" + +from typing import Generator, Iterator, Sequence, Tuple + +import jax.tree_util as tree +import jraph +import numpy as np + +_NUMBER_FIELDS = ("n_node", "n_edge", "n_graph") + + +def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple], + n_node: int, n_edge: int, + n_graph: int) -> Generator[jraph.GraphsTuple, None, None]: + """Dynamically batches trees with `jraph.GraphsTuples` to `graph_batch_size`. + + Elements of the `graphs_tuple_iterator` will be incrementally added to a batch + until the limits defined by `n_node`, `n_edge` and `n_graph` are reached. This + means each element yielded by this generator may contain a different number of + graphs, while always being padded to the same size. + + For situations where you have variable sized data, it's useful to be able to + have variable sized batches. This is especially the case if you have a loss + defined on the variable shaped element (for example, nodes in a graph). + + Args: + graphs_tuple_iterator: An iterator of `jraph.GraphsTuples`. + n_node: The maximum number of nodes in a batch. + n_edge: The maximum number of edges in a batch. + n_graph: The maximum number of graphs in a batch. + + Yields: + A `jraph.GraphsTuple` batch of graphs. + + Raises: + ValueError: if the number of graphs is < 2. + RuntimeError: if the `graphs_tuple_iterator` contains elements which are not + `jraph.GraphsTuple`s. + RuntimeError: if a graph is found which is larger than the batch size. 
+ """ + if n_graph < 2: + raise ValueError("The number of graphs in a batch size must be greater or " + f"equal to `2` for padding with graphs, got {n_graph}.") + valid_batch_size = (n_node - 1, n_edge, n_graph - 1) + accumulated_graphs = [] + num_accumulated_nodes = 0 + num_accumulated_edges = 0 + num_accumulated_graphs = 0 + for element in graphs_tuple_iterator: + element_nodes, element_edges, element_graphs = _get_graph_size(element) + if _is_over_batch_size(element, valid_batch_size): + graph_size = element_nodes, element_edges, element_graphs + graph_size = {k: v for k, v in zip(_NUMBER_FIELDS, graph_size)} + batch_size = {k: v for k, v in zip(_NUMBER_FIELDS, valid_batch_size)} + raise RuntimeError("Found graph bigger than batch size. Valid Batch " + f"Size: {batch_size}, Graph Size: {graph_size}") + + if not accumulated_graphs: + # If this is the first element of the batch, set it and continue. + accumulated_graphs = [element] + num_accumulated_nodes = element_nodes + num_accumulated_edges = element_edges + num_accumulated_graphs = element_graphs + continue + else: + # Otherwise check if there is space for the graph in the batch: + if ((num_accumulated_graphs + element_graphs > n_graph - 1) or + (num_accumulated_nodes + element_nodes > n_node - 1) or + (num_accumulated_edges + element_edges > n_edge)): + # If there is no space, yield the current batch and start a new one. + batched_graph = _batch_np(accumulated_graphs) + yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph) + accumulated_graphs = [element] + num_accumulated_nodes = element_nodes + num_accumulated_edges = element_edges + num_accumulated_graphs = element_graphs + else: + # Otherwise, add the graph to the current batch. + accumulated_graphs.append(element) + num_accumulated_nodes += element_nodes + num_accumulated_edges += element_edges + num_accumulated_graphs += element_graphs + + # We may still have data in batched graph. 
+ if accumulated_graphs: + batched_graph = _batch_np(accumulated_graphs) + yield jraph.pad_with_graphs(batched_graph, n_node, n_edge, n_graph) + + +def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple: + # Calculates offsets for sender and receiver arrays, caused by concatenating + # the nodes arrays. + offsets = np.cumsum(np.array([0] + [np.sum(g.n_node) for g in graphs[:-1]])) + + def _map_concat(nests): + concat = lambda *args: np.concatenate(args) + return tree.tree_multimap(concat, *nests) + + return jraph.GraphsTuple( + n_node=np.concatenate([g.n_node for g in graphs]), + n_edge=np.concatenate([g.n_edge for g in graphs]), + nodes=_map_concat([g.nodes for g in graphs]), + edges=_map_concat([g.edges for g in graphs]), + globals=_map_concat([g.globals for g in graphs]), + senders=np.concatenate([g.senders + o for g, o in zip(graphs, offsets)]), + receivers=np.concatenate( + [g.receivers + o for g, o in zip(graphs, offsets)])) + + +def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]: + n_node = np.sum(graph.n_node) + n_edge = len(graph.senders) + n_graph = len(graph.n_node) + return n_node, n_edge, n_graph + + +def _is_over_batch_size( + graph: jraph.GraphsTuple, + graph_batch_size: Tuple[int, int, int], +) -> bool: + graph_size = _get_graph_size(graph) + return any([x > y for x, y in zip(graph_size, graph_batch_size)]) diff --git a/ogb_lsc/pcq/config.py b/ogb_lsc/pcq/config.py new file mode 100644 index 0000000..9e8d316 --- /dev/null +++ b/ogb_lsc/pcq/config.py @@ -0,0 +1,130 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experiment config for PCQM4M-LSC entry.""" + +from jaxline import base_config +from ml_collections import config_dict + + +def get_config(debug: bool = False) -> config_dict.ConfigDict: + """Get Jaxline experiment config.""" + config = base_config.get_base_config() + # E.g. '/data/pretrained_models/k0_seed100' (and set k_fold_split_id=0, below) + config.restore_path = config_dict.placeholder(str) + + training_batch_size = 64 + eval_batch_size = 64 + + ## Experiment config. + loss_config_name = 'RegressionLossConfig' + loss_kwargs = dict( + exponent=1., # 2 for l2 loss, 1 for l1 loss, etc... + ) + + dataset_config = dict( + data_root=config_dict.placeholder(str), + augment_with_random_mirror_symmetry=True, + k_fold_split_id=config_dict.placeholder(int), + num_k_fold_splits=config_dict.placeholder(int), + # Options: "in" or "out". + # Filter=in would keep the samples with nans in the conformer features. + # Filter=out would keep the samples with no NaNs anywhere in the conformer + # features. 
+ filter_in_or_out_samples_with_nans_in_conformers=( + config_dict.placeholder(str)), + cached_conformers_file=config_dict.placeholder(str)) + + model_config = dict( + mlp_hidden_size=512, + mlp_layers=2, + latent_size=512, + use_layer_norm=False, + num_message_passing_steps=32, + shared_message_passing_weights=False, + mask_padding_graph_at_every_step=True, + loss_config_name=loss_config_name, + loss_kwargs=loss_kwargs, + processor_mode='resnet', + global_reducer='sum', + node_reducer='sum', + dropedge_rate=0.1, + dropnode_rate=0.1, + aux_multiplier=0.1, + add_relative_distance=True, + add_relative_displacement=True, + add_absolute_positions=False, + position_normalization=2., + relative_displacement_normalization=1., + ignore_globals=False, + ignore_globals_from_final_layer_for_predictions=True, + ) + + if debug: + # Make network smaller. + model_config.update(dict( + mlp_hidden_size=32, + mlp_layers=1, + latent_size=32, + num_message_passing_steps=1)) + + config.experiment_kwargs = config_dict.ConfigDict( + dict( + config=dict( + debug=debug, + predictions_dir=config_dict.placeholder(str), + ema=True, + ema_decay=0.9999, + sample_random=0.05, + optimizer=dict( + name='adam', + optimizer_kwargs=dict(b1=.9, b2=.95), + lr_schedule=dict( + warmup_steps=int(5e4), + decay_steps=int(5e5), + init_value=1e-5, + peak_value=1e-4, + end_value=0., + ), + ), + model=model_config, + dataset_config=dataset_config, + # As a rule of thumb, use the following statistics: + # Avg. # nodes in graph: 16. + # Avg. # edges in graph: 40. + training=dict( + dynamic_batch_size={ + 'n_node': 256 if debug else 16 * training_batch_size, + 'n_edge': 512 if debug else 40 * training_batch_size, + 'n_graph': 2 if debug else training_batch_size, + },), + evaluation=dict( + split='valid', + dynamic_batch_size=dict( + n_node=256 if debug else 16 * eval_batch_size, + n_edge=512 if debug else 40 * eval_batch_size, + n_graph=2 if debug else eval_batch_size, + ))))) + + ## Training loop config. 
+ config.training_steps = int(5e6) + config.checkpoint_dir = '/tmp/checkpoint/pcq/' + config.train_checkpoint_all_hosts = False + config.save_checkpoint_interval = 300 + config.log_train_data_interval = 60 + config.log_tensors_interval = 60 + config.best_model_eval_metric = 'mae' + config.best_model_eval_metric_higher_is_better = False + + return config diff --git a/ogb_lsc/pcq/conformer_utils.py b/ogb_lsc/pcq/conformer_utils.py new file mode 100644 index 0000000..3217a4b --- /dev/null +++ b/ogb_lsc/pcq/conformer_utils.py @@ -0,0 +1,324 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Conformer utilities.""" + +import copy +from typing import List, Optional +from absl import logging +import numpy as np +import rdkit +from rdkit import Chem +from rdkit.Chem import AllChem +import tensorflow.compat.v2 as tf + + +def generate_conformers( + molecule: Chem.rdchem.Mol, + max_num_conformers: int, + *, + random_seed: int = -1, + prune_rms_thresh: float = -1.0, + max_iter: int = -1, + fallback_to_random: bool = False, +) -> Chem.rdchem.Mol: + """Generates conformers for a given molecule. + + Args: + molecule: molecular representation of the compound. + max_num_conformers: maximum number of conformers to generate. If pruning is + done, the returned number of conformers is not guaranteed to match + max_num_conformers. + random_seed: random seed to use for conformer generation. 
+ prune_rms_thresh: RMSD threshold used to prune conformers that are + too similar. + max_iter: Maximum number of iterations to perform when optimising MMFF force + field. If set to <= 0, energy optimisation is not performed. + fallback_to_random: if conformers cannot be obtained, use random coordinates + to initialise. + + Returns: + Copy of `molecule` containing the generated conformers, which are force + field-optimised if `max_iter > 0`. Hydrogens are added for embedding and + removed again before returning. The number of conformers is guaranteed to + be <= max_num_conformers. + """ + mol = copy.deepcopy(molecule) + mol = Chem.AddHs(mol) + mol = _embed_conformers( + mol, + max_num_conformers, + random_seed, + prune_rms_thresh, + fallback_to_random, + use_random=False) + + if max_iter > 0: + mol_with_conformers = _minimize_by_mmff(mol, max_iter) + if mol_with_conformers is None: + mol_with_conformers = _minimize_by_uff(mol, max_iter) + else: + mol_with_conformers = mol + # Aligns conformations in a molecule to each other using the first + # conformation as the reference. + AllChem.AlignMolConformers(mol_with_conformers) + + # We remove hydrogens to keep the number of atoms consistent with the graph + # nodes. + mol_with_conformers = Chem.RemoveHs(mol_with_conformers) + + return mol_with_conformers + + +def atom_to_feature_vector( + atom: rdkit.Chem.rdchem.Atom, + conformer: Optional[np.ndarray] = None, +) -> List[float]: + """Converts rdkit atom object to a list of position features. + + Args: + atom: rdkit atom object. + conformer: Generated conformer. NaN values are returned if set to None. + + Returns: + List containing the position (x, y, z) of the atom in the conformer. + """ + if conformer: + pos = conformer.GetAtomPosition(atom.GetIdx()) + return [pos.x, pos.y, pos.z] + return [np.nan, np.nan, np.nan] + + +def compute_conformer(smile: str, max_iter: int = -1) -> np.ndarray: + """Computes conformer. + + Args: + smile: Smile string. + max_iter: Maximum number of iterations to perform when optimising MMFF force + field. 
If set to <= 0, energy optimisation is not performed. + + Returns: + A float32 array with one (x, y, z) row per atom, holding the conformer + positions. All values are NaN if conformer generation failed. + + Raises: + RuntimeError: If unable to convert smile string to RDKit mol. + """ + mol = rdkit.Chem.MolFromSmiles(smile) + if not mol: + raise RuntimeError('Unable to convert smile to molecule: %s' % smile) + conformer_failed = False + try: + mol = generate_conformers( + mol, + max_num_conformers=1, + random_seed=45, + prune_rms_thresh=0.01, + max_iter=max_iter) + except IOError as e: + logging.exception('Failed to generate conformers for %s . IOError %s.', + smile, e) + conformer_failed = True + except ValueError: + logging.error('Failed to generate conformers for %s . ValueError', smile) + conformer_failed = True + except: # pylint: disable=bare-except + logging.error('Failed to generate conformers for %s.', smile) + conformer_failed = True + + atom_features_list = [] + conformer = None if conformer_failed else list(mol.GetConformers())[0] + for atom in mol.GetAtoms(): + atom_features_list.append(atom_to_feature_vector(atom, conformer)) + conformer_features = np.array(atom_features_list, dtype=np.float32) + return conformer_features + + +def get_random_rotation_matrix(include_mirror_symmetry: bool) -> tf.Tensor: + """Returns a single random rotation matrix.""" + rotation_matrix = _get_random_rotation_3d() + if include_mirror_symmetry: + random_mirror_symmetry = _get_random_mirror_symmetry() + rotation_matrix = tf.matmul(rotation_matrix, random_mirror_symmetry) + + return rotation_matrix + + +def rotate(vectors: tf.Tensor, rotation_matrix: tf.Tensor) -> tf.Tensor: + """Rotates a batch of vectors by a single rotation matrix.""" + return tf.matmul(vectors, rotation_matrix) + + +def _embed_conformers( + molecule: Chem.rdchem.Mol, + max_num_conformers: int, + random_seed: int, + prune_rms_thresh: float, + fallback_to_random: bool, + *, + use_random: bool = False, +) -> Chem.rdchem.Mol: + """Embeds conformers into a copy of a molecule. 
+ + If random coordinates allowed, tries not to use random coordinates at first, + and uses random only if fails. + + Args: + molecule: molecular representation of the compound. + max_num_conformers: maximum number of conformers to generate. If pruning is + done, the returned number of conformers is not guaranteed to match + max_num_conformers. + random_seed: random seed to use for conformer generation. + prune_rms_thresh: RMSD threshold which allows to prune conformers that are + too similar. + fallback_to_random: if conformers cannot be obtained, use random coordinates + to initialise. + *: + use_random: Use random coordinates. Shouldn't be set by any caller except + this function itself. + + Returns: + A copy of a molecule with embedded conformers. + + Raises: + ValueError: if conformers cannot be obtained for a given molecule. + """ + mol = copy.deepcopy(molecule) + + # Obtains parameters for conformer generation. + # In particular, ETKDG is experimental-torsion basic knowledge distance + # geometry, which allows to randomly generate an initial conformation that + # satisfies various geometric constraints such as lower and upper bounds on + # the distances between atoms. + params = AllChem.ETKDGv3() + + params.randomSeed = random_seed + params.pruneRmsThresh = prune_rms_thresh + params.numThreads = -1 + params.useRandomCoords = use_random + + conf_ids = AllChem.EmbedMultipleConfs(mol, max_num_conformers, params) + + if not conf_ids: + if not fallback_to_random or use_random: + raise ValueError('Cant get conformers') + return _embed_conformers( + mol, + max_num_conformers, + random_seed, + prune_rms_thresh, + fallback_to_random, + use_random=True) + return mol + + +def _minimize_by_mmff( + molecule: Chem.rdchem.Mol, + max_iter: int, +) -> Optional[Chem.rdchem.Mol]: + """Minimizes forcefield for conformers using MMFF algorithm. + + Args: + molecule: a datastructure containing conformers. + max_iter: number of maximum iterations to use when optimising force field. 
+ + Returns: + A copy of a `molecule` containing optimised conformers; or None if MMFF + cannot be performed. + """ + molecule_props = AllChem.MMFFGetMoleculeProperties(molecule) + if molecule_props is None: + return None + + mol = copy.deepcopy(molecule) + for conf_id in range(mol.GetNumConformers()): + ff = AllChem.MMFFGetMoleculeForceField( + mol, molecule_props, confId=conf_id, ignoreInterfragInteractions=False) + ff.Initialize() + # minimises a conformer within a mol in place. + ff.Minimize(max_iter) + return mol + + +def _minimize_by_uff( + molecule: Chem.rdchem.Mol, + max_iter: int, +) -> Chem.rdchem.Mol: + """Minimizes forcefield for conformers using UFF algorithm. + + Args: + molecule: a datastructure containing conformers. + max_iter: number of maximum iterations to use when optimising force field. + + Returns: + A copy of a `molecule` containing optimised conformers. + """ + mol = copy.deepcopy(molecule) + conf_ids = range(mol.GetNumConformers()) + for conf_id in conf_ids: + ff = AllChem.UFFGetMoleculeForceField(mol, confId=conf_id) + ff.Initialize() + # minimises a conformer within a mol in place. 
+ ff.Minimize(max_iter) + return mol + + +def _get_symmetry_rotation_matrix(sign: tf.Tensor) -> tf.Tensor: + """Returns the 2d/3d matrix for mirror symmetry.""" + zero = tf.zeros_like(sign) + one = tf.ones_like(sign) + # pylint: disable=bad-whitespace,bad-continuation + rot = [sign, zero, zero, + zero, one, zero, + zero, zero, one] + # pylint: enable=bad-whitespace,bad-continuation + shape = (3, 3) + rot = tf.stack(rot, axis=-1) + rot = tf.reshape(rot, shape) + return rot + + +def _quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor: + """Converts a batch of quaternions to a batch of rotation matrices.""" + q0 = quaternion[0] + q1 = quaternion[1] + q2 = quaternion[2] + q3 = quaternion[3] + + r00 = 2 * (q0 * q0 + q1 * q1) - 1 + r01 = 2 * (q1 * q2 - q0 * q3) + r02 = 2 * (q1 * q3 + q0 * q2) + r10 = 2 * (q1 * q2 + q0 * q3) + r11 = 2 * (q0 * q0 + q2 * q2) - 1 + r12 = 2 * (q2 * q3 - q0 * q1) + r20 = 2 * (q1 * q3 - q0 * q2) + r21 = 2 * (q2 * q3 + q0 * q1) + r22 = 2 * (q0 * q0 + q3 * q3) - 1 + + matrix = tf.stack([r00, r01, r02, + r10, r11, r12, + r20, r21, r22], axis=-1) + return tf.reshape(matrix, [3, 3]) + + +def _get_random_rotation_3d() -> tf.Tensor: + random_quaternions = tf.random.normal( + shape=[4], dtype=tf.float32) + random_quaternions /= tf.linalg.norm( + random_quaternions, axis=-1, keepdims=True) + return _quaternion_to_rotation_matrix(random_quaternions) + + +def _get_random_mirror_symmetry() -> tf.Tensor: + random_0_1 = tf.random.uniform( + shape=(), minval=0, maxval=2, dtype=tf.int32) + random_signs = tf.cast((2 * random_0_1) - 1, tf.float32) + return _get_symmetry_rotation_matrix(random_signs) diff --git a/ogb_lsc/pcq/dataset_utils.py b/ogb_lsc/pcq/dataset_utils.py new file mode 100644 index 0000000..8bfdfa1 --- /dev/null +++ b/ogb_lsc/pcq/dataset_utils.py @@ -0,0 +1,391 @@ +# Copyright 2021 DeepMind Technologies Limited. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Dataset utilities.""" + +from typing import List, Optional + +import jax +import jraph +from ml_collections import config_dict +import numpy as np +from ogb import utils +from ogb.utils import features +import tensorflow.compat.v2 as tf +import tensorflow_datasets as tfds +import tree + +# pylint: disable=g-bad-import-order +# pytype: disable=import-error +import batching_utils +import conformer_utils +import datasets + + +def build_dataset_iterator( + data_root: str, + split: str, + dynamic_batch_size_config: config_dict.ConfigDict, + sample_random: float, + cached_conformers_file: str, + debug: bool = False, + is_training: bool = True, + augment_with_random_mirror_symmetry: bool = False, + positions_noise_std: Optional[float] = None, + k_fold_split_id: Optional[int] = None, + num_k_fold_splits: Optional[int] = None, + filter_in_or_out_samples_with_nans_in_conformers: Optional[str] = None, +): + """Returns an iterator over Batches from the dataset.""" + if debug: + max_items_to_read_from_dataset = 10 + prefetch_buffer_size = 1 + shuffle_buffer_size = 1 + else: + max_items_to_read_from_dataset = -1 # < 0 means no limit. + prefetch_buffer_size = 64 + # It can take a while to fill the shuffle buffer with k fold splits. + shuffle_buffer_size = 128 if k_fold_split_id is None else int(1e6) + + num_local_devices = jax.local_device_count() + + # Load all smile strings. 
+ indices, smiles, labels = _load_smiles( + data_root, + split, + k_fold_split_id=k_fold_split_id, + num_k_fold_splits=num_k_fold_splits) + if debug: + indices = indices[:100] + smiles = smiles[:100] + labels = labels[:100] + # Generate all conformer features from smile strings ahead of time. + # This gives us a boost from multi-parallelism as opposed to doing it + # online. + conformers = _load_conformers(indices, smiles, cached_conformers_file) + + data_generator = ( + lambda: _get_pcq_graph_generator(indices, smiles, labels, conformers)) + # Create a dataset yielding graphs from smile strings. + example = next(data_generator()) + signature_from_example = tree.map_structure(_numpy_to_tensor_spec, example) + ds = tf.data.Dataset.from_generator( + data_generator, output_signature=signature_from_example) + + ds = ds.take(max_items_to_read_from_dataset) + ds = ds.cache() + if is_training: + ds = ds.shuffle(shuffle_buffer_size) + + # Apply transformations. + def map_fn(graph, conformer_positions): + graph = _maybe_one_hot_atoms_with_noise( + graph, is_training=is_training, sample_random=sample_random) + # Add conformer features. 
+ graph = _add_conformer_features( + graph, + conformer_positions, + augment_with_random_mirror_symmetry=augment_with_random_mirror_symmetry, + noise_std=positions_noise_std, + is_training=is_training, + ) + return _downcast_ints(graph) + + ds = ds.map(map_fn, num_parallel_calls=tf.data.AUTOTUNE) + if filter_in_or_out_samples_with_nans_in_conformers: + if filter_in_or_out_samples_with_nans_in_conformers not in ("in", "out"): + raise ValueError( + "Unknown value specified for the argument " + "`filter_in_or_out_samples_with_nans_in_conformers`: %s" % + filter_in_or_out_samples_with_nans_in_conformers) + + filter_fn = _get_conformer_filter( + with_nans=(filter_in_or_out_samples_with_nans_in_conformers == "in")) + ds = ds.filter(filter_fn) + + if is_training: + ds = ds.shard(jax.process_count(), jax.process_index()) + ds = ds.repeat() + + ds = ds.prefetch(prefetch_buffer_size) + it = tfds.as_numpy(ds) + + # Dynamic batching. + batched_gen = batching_utils.dynamically_batch( + it, + n_node=dynamic_batch_size_config.n_node + 1, + n_edge=dynamic_batch_size_config.n_edge, + n_graph=dynamic_batch_size_config.n_graph + 1, + ) + + if is_training: + # Stack `num_local_devices` of batches together for pmap updates. + batch_size = num_local_devices + + def _batch(l): + assert l + return tree.map_structure(lambda *l: np.stack(l, axis=0), *l) + + def batcher_fn(): + batch = [] + for sample in batched_gen: + batch.append(sample) + if len(batch) == batch_size: + yield _batch(batch) + batch = [] + if batch: + yield _batch(batch) + + for sample in batcher_fn(): + yield sample + else: + for sample in batched_gen: + yield sample + + +def _get_conformer_filter(with_nans: bool): + """Selects a conformer filter to apply. + + Args: + with_nans: Filter only selects samples with NaNs in conformer features. + Else, selects samples without any NaNs in conformer features. + + Returns: + A function that can be used with tf.data.Dataset.filter(). 
+ + Raises: + ValueError: + If the input graph to the filter has no conformer features to filter. + """ + + def _filter(graph: jraph.GraphsTuple) -> tf.Tensor: + + if ("positions" not in graph.nodes) or ( + "positions_targets" not in graph.nodes) or ( + "positions_nan_mask" not in graph.globals): + raise ValueError("Conformer features not available to filter.") + + any_nan = tf.logical_not(tf.squeeze(graph.globals["positions_nan_mask"])) + return any_nan if with_nans else tf.logical_not(any_nan) + + return _filter + + +def _numpy_to_tensor_spec(arr: np.ndarray) -> tf.TensorSpec: + if not isinstance(arr, np.ndarray): + return tf.TensorSpec([], + dtype=tf.int32 if isinstance(arr, int) else tf.float32) + elif arr.shape: + return tf.TensorSpec((None,) + arr.shape[1:], arr.dtype) + else: + return tf.TensorSpec([], arr.dtype) + + +def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor: + return tf.random.categorical(tf.math.log([[1 / size] * size]), num)[0] + + +@jax.curry(jax.tree_map) +def _downcast_ints(x): + if x.dtype == tf.int64: + return tf.cast(x, tf.int32) + return x + + +def _one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor: + vocab_sizes = features.get_atom_feature_dims() + one_hots = [] + for i in range(atoms.shape[1]): + one_hots.append(tf.one_hot(atoms[:, i], vocab_sizes[i], dtype=tf.float32)) + return tf.concat(one_hots, axis=-1) + + +def _sample_one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor: + vocab_sizes = features.get_atom_feature_dims() + one_hots = [] + num_atoms = tf.shape(atoms)[0] + for i in range(atoms.shape[1]): + sampled_category = _sample_uniform_categorical(num_atoms, vocab_sizes[i]) + one_hots.append( + tf.one_hot(sampled_category, vocab_sizes[i], dtype=tf.float32)) + return tf.concat(one_hots, axis=-1) + + +def _one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor: + vocab_sizes = features.get_bond_feature_dims() + one_hots = [] + for i in range(bonds.shape[1]): + one_hots.append(tf.one_hot(bonds[:, i], vocab_sizes[i], dtype=tf.float32)) + 
return tf.concat(one_hots, axis=-1) + + +def _sample_one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor: + vocab_sizes = features.get_bond_feature_dims() + one_hots = [] + num_bonds = tf.shape(bonds)[0] + for i in range(bonds.shape[1]): + sampled_category = _sample_uniform_categorical(num_bonds, vocab_sizes[i]) + one_hots.append( + tf.one_hot(sampled_category, vocab_sizes[i], dtype=tf.float32)) + return tf.concat(one_hots, axis=-1) + + +def _maybe_one_hot_atoms_with_noise( + x, + is_training: bool, + sample_random: float, +): + """One hot atoms with noise.""" + gt_nodes = _one_hot_atoms(x.nodes) + gt_edges = _one_hot_bonds(x.edges) + if is_training: + num_nodes = tf.shape(x.nodes)[0] + sample_node_or_not = tf.random.uniform([num_nodes], + maxval=1) < sample_random + nodes = tf.where( + tf.expand_dims(sample_node_or_not, axis=-1), + _sample_one_hot_atoms(x.nodes), gt_nodes) + num_edges = tf.shape(x.edges)[0] + sample_edges_or_not = tf.random.uniform([num_edges], + maxval=1) < sample_random + edges = tf.where( + tf.expand_dims(sample_edges_or_not, axis=-1), + _sample_one_hot_bonds(x.edges), gt_edges) + else: + nodes = gt_nodes + edges = gt_edges + return x._replace( + nodes={ + "atom_one_hots_targets": gt_nodes, + "atom_one_hots": nodes, + }, + edges={ + "bond_one_hots_targets": gt_edges, + "bond_one_hots": edges + }) + + +def _load_smiles( + data_root: str, + split: str, + k_fold_split_id: int, + num_k_fold_splits: int, +): + """Loads SMILES strings for the input split.""" + + if split == "test" or k_fold_split_id is None: + indices = datasets.load_splits()[split] + elif split == "train": + indices = datasets.load_all_except_kth_fold_indices( + data_root, k_fold_split_id, num_k_fold_splits) + else: + assert split == "valid" + indices = datasets.load_kth_fold_indices(data_root, k_fold_split_id) + + smiles_and_labels = datasets.load_smile_strings(with_labels=True) + smiles, labels = list(zip(*smiles_and_labels)) + return indices, [smiles[i] for i in indices], [labels[i] for
i in indices] + + +def _convert_ogb_graph_to_graphs_tuple(ogb_graph): + """Converts an OGB Graph to a GraphsTuple.""" + senders = ogb_graph["edge_index"][0] + receivers = ogb_graph["edge_index"][1] + edges = ogb_graph["edge_feat"] + nodes = ogb_graph["node_feat"] + n_node = np.array([ogb_graph["num_nodes"]]) + n_edge = np.array([len(senders)]) + graph = jraph.GraphsTuple( + nodes=nodes, + edges=edges, + senders=senders, + receivers=receivers, + n_node=n_node, + n_edge=n_edge, + globals=None) + return tree.map_structure(lambda x: x if x is not None else np.array(0.), + graph) + + +def _load_conformers(indices: List[int], + smiles: List[str], + cached_conformers_file: str): + """Loads conformers.""" + smile_to_conformer = datasets.load_cached_conformers(cached_conformers_file) + conformers = [] + for graph_idx, smile in zip(indices, smiles): + del graph_idx # Unused. + if smile not in smile_to_conformer: + raise KeyError("Cache did not have conformer entry for the smile %s" % + str(smile)) + conformers.append(dict(conformer=smile_to_conformer[smile])) + return conformers + + +def _add_conformer_features( + graph, + conformer_features, + augment_with_random_mirror_symmetry: bool, + noise_std: float, + is_training: bool, +): + """Adds conformer features.""" + if not isinstance(graph.nodes, dict): + raise ValueError("Expected a dict type for `graph.nodes`.") + # Remove mean position to center around a canonical origin. + positions = conformer_features["conformer"] + # NaN's appear in ~0.13% of training, 0.104% of validation and 0.16% of test + # nodes. + # See this colab: http://shortn/_6UcuosxY7x. + nan_mask = tf.reduce_any(tf.math.is_nan(positions)) + + positions = tf.where(nan_mask, tf.constant(0., positions.dtype), positions) + positions -= tf.reduce_mean(positions, axis=0, keepdims=True) + + # Optionally augment with a random rotation. 
+ if is_training: + rot_mat = conformer_utils.get_random_rotation_matrix( + augment_with_random_mirror_symmetry) + positions = conformer_utils.rotate(positions, rot_mat) + positions_targets = positions + + # Optionally add noise to the positions. + if noise_std and is_training: + positions = tf.random.normal(tf.shape(positions), positions, noise_std) + + return graph._replace( + nodes=dict( + positions=positions, + positions_targets=positions_targets, + **graph.nodes), + globals={ + "positions_nan_mask": + tf.expand_dims(tf.logical_not(nan_mask), axis=0), + **(graph.globals if isinstance(graph.globals, dict) else {}) + }) + + +def _get_pcq_graph_generator(indices, smiles, labels, conformers): + """Returns a generator to yield graph.""" + for idx, smile, conformer_positions, label in zip(indices, smiles, conformers, + labels): + graph = utils.smiles2graph(smile) + graph = _convert_ogb_graph_to_graphs_tuple(graph) + graph = graph._replace( + globals={ + "target": np.array([label], dtype=np.float32), + "graph_index": np.array([idx], dtype=np.int32), + **(graph.globals if isinstance(graph.globals, dict) else {}) + }) + yield graph, conformer_positions diff --git a/ogb_lsc/pcq/datasets.py b/ogb_lsc/pcq/datasets.py new file mode 100644 index 0000000..a307777 --- /dev/null +++ b/ogb_lsc/pcq/datasets.py @@ -0,0 +1,83 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PCQM4M-LSC datasets.""" + +import functools +import pickle +from typing import Dict, List, Tuple, Union + +import numpy as np +from ogb import lsc + +NUM_VALID_SAMPLES = 380_670 +NUM_TEST_SAMPLES = 377_423 + +NORMALIZE_TARGET_MEAN = 5.690944545356371 +NORMALIZE_TARGET_STD = 1.1561347795107815 + + +def load_splits() -> Dict[str, List[int]]: + """Loads dataset splits.""" + dataset = _get_pcq_dataset(only_smiles=True) + return dataset.get_idx_split() + + +def load_kth_fold_indices(data_root: str, k_fold_split_id: int) -> List[int]: + """Loads k-th fold indices.""" + fname = f"{data_root}/k_fold_splits/{k_fold_split_id}.pkl" + return list(map(int, _load_pickle(fname))) + + +def load_all_except_kth_fold_indices(data_root: str, k_fold_split_id: int, + num_k_fold_splits: int) -> List[int]: + """Loads indices except for the kth fold.""" + if k_fold_split_id is None: + raise ValueError("Expected integer value for `k_fold_split_id`.") + indices = [] + for index in range(num_k_fold_splits): + if index != k_fold_split_id: + indices += load_kth_fold_indices(data_root, index) + return indices + + +def load_smile_strings( + with_labels=False) -> List[Union[str, Tuple[str, np.ndarray]]]: + """Loads the smile strings in the PCQ dataset.""" + dataset = _get_pcq_dataset(only_smiles=True) + smiles = [] + for i in range(len(dataset)): + smile, label = dataset[i] + if with_labels: + smiles.append((smile, label)) + else: + smiles.append(smile) + + return smiles + + +@functools.lru_cache() +def load_cached_conformers(cached_fname: str) -> Dict[str, np.ndarray]: + """Returns cached dict mapping smile strings to conformer features.""" + return _load_pickle(cached_fname) + + +@functools.lru_cache() +def _get_pcq_dataset(only_smiles: bool): + return lsc.PCQM4MDataset(only_smiles=only_smiles) + + +def _load_pickle(fname: str): + with open(fname, "rb") as f: + return pickle.load(f) diff --git a/ogb_lsc/pcq/download_pcq.py b/ogb_lsc/pcq/download_pcq.py new file mode 100644 index 
0000000..4a8c431 --- /dev/null +++ b/ogb_lsc/pcq/download_pcq.py @@ -0,0 +1,99 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Download data required for training and evaluating models.""" + +import pathlib + +from absl import app +from absl import flags +from absl import logging +from google.cloud import storage + +Path = pathlib.Path + + +_BUCKET_NAME = 'deepmind-ogb-lsc' +_MAX_DOWNLOAD_ATTEMPTS = 5 + +FLAGS = flags.FLAGS + +flags.DEFINE_enum('payload', None, ['data', 'models'], + 'Download "data" or "models"?') +flags.DEFINE_string('task_root', None, 'Local task root directory') + +# GCS_DATA_ROOT = Path('pcq/data') +# GCS_MODEL_ROOT = Path('pcq/models') + +DATA_RELATIVE_PATHS = [ + 'raw/data.csv.gz', 'preprocessed/smile_to_conformer.pkl', + 'k_fold_splits' +] + + +class DataCorruptionError(Exception): + pass + + +def _get_gcs_root(): + return Path('pcq') / FLAGS.payload + + +def _get_gcs_bucket(): + storage_client = storage.Client.create_anonymous_client() + return storage_client.bucket(_BUCKET_NAME) + + +def _write_blob_to_destination(blob, task_root, ignore_existing=True): + """Write the blob.""" + logging.info("Copying blob: '%s'", blob.name) + destination_path = Path(task_root) / Path(*Path(blob.name).parts[1:]) + logging.info(" ... 
to: '%s'", str(destination_path)) + if ignore_existing and destination_path.exists(): + return + destination_path.parent.mkdir(parents=True, exist_ok=True) + checksum = 'crc32c' + for attempt in range(_MAX_DOWNLOAD_ATTEMPTS): + try: + blob.download_to_filename(destination_path.as_posix(), checksum=checksum) + except storage.client.resumable_media.common.DataCorruption: + pass + else: + break + else: + raise DataCorruptionError(f"Checksum ('{checksum}') for {blob.name} failed " + f'after {attempt + 1} attempts') + + +def main(unused_argv): + bucket = _get_gcs_bucket() + if FLAGS.payload == 'data': + relative_paths = DATA_RELATIVE_PATHS + else: + relative_paths = (None,) + for relative_path in relative_paths: + if relative_path is None: + relative_path = str(_get_gcs_root()) + else: + relative_path = str(_get_gcs_root() / relative_path) + logging.info("Copying relative path: '%s'", relative_path) + blobs = bucket.list_blobs(prefix=relative_path) + for blob in blobs: + _write_blob_to_destination(blob, FLAGS.task_root) + + +if __name__ == '__main__': + flags.mark_flag_as_required('payload') + flags.mark_flag_as_required('task_root') + app.run(main) diff --git a/ogb_lsc/pcq/ensemble_predictions.py b/ogb_lsc/pcq/ensemble_predictions.py new file mode 100644 index 0000000..392afbc --- /dev/null +++ b/ogb_lsc/pcq/ensemble_predictions.py @@ -0,0 +1,258 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Script to generate ensembled PCQ test predictions.""" + +import collections +import os +import pathlib +from typing import List, NamedTuple + +from absl import app +from absl import flags +from absl import logging +import dill +import numpy as np +from ogb import lsc + +# pylint: disable=g-bad-import-order +# pytype: disable=import-error +import datasets + +_NUM_SEEDS = 2 + +_CLIP_VALUE = 20. + +_NUM_KFOLD_SPLITS = 10 + +_SEED_START = flags.DEFINE_integer( + 'seed_start', 42, 'Initial seed for the list of ensemble models.') + +_CONFORMER_PATH = flags.DEFINE_string( + 'conformer_path', None, 'Path to conformer predictions.', required=True) + +_NON_CONFORMER_PATH = flags.DEFINE_string( + 'non_conformer_path', + None, + 'Path to non-conformer predictions.', + required=True) + +_OUTPUT_PATH = flags.DEFINE_string('output_path', None, 'Output path.') + +_SPLIT = flags.DEFINE_enum('split', 'test', ['test', 'valid'], + 'Split: valid or test.') + + +class _Predictions(NamedTuple): + predictions: np.ndarray + indices: np.ndarray + + +def _load_dill(fname) -> bytes: + with open(fname, 'rb') as f: + return dill.load(f) + + +def _sort_by_indices(predictions: _Predictions) -> _Predictions: + order = np.argsort(predictions.indices) + return _Predictions( + predictions=predictions.predictions[order], + indices=predictions.indices[order]) + + +def load_predictions(path: str, split: str) -> _Predictions: + """Load written prediction file.""" + if len(os.listdir(path)) != 1: + raise ValueError('Prediction directory must have exactly ' + 'one prediction sub-directory: %s' % path) + + prediction_subdir = os.listdir(path)[0] + return _Predictions(*_load_dill(f'{path}/{prediction_subdir}/{split}.dill')) + + +def mean_mae_distance(x, y): + return np.abs(x - y).mean() + + +def _load_valid_labels() -> np.ndarray: + labels = [label for _, label in datasets.load_smile_strings(with_labels=True)] + return np.array([labels[i] for i in datasets.load_splits()['valid']]) + + +def 
evaluate_valid_predictions(ensembled_predictions: _Predictions): + """Evaluates the predictions on the validation set.""" + ensembled_predictions = _sort_by_indices(ensembled_predictions) + evaluator = lsc.PCQM4MEvaluator() + results = evaluator.eval( + dict( + y_pred=ensembled_predictions.predictions, + y_true=_load_valid_labels())) + logging.info('MAE on validation dataset: %f', results['mae']) + + +def clip_predictions(predictions: _Predictions) -> _Predictions: + return predictions._replace( + predictions=np.clip(predictions.predictions, 0., _CLIP_VALUE)) + + +def _generate_test_prediction_file(test_predictions: np.ndarray, + output_path: pathlib.Path) -> pathlib.Path: + """Generates the final file for submission.""" + + # Check that predictions are not nuts. + assert test_predictions.dtype in [np.float64, np.float32] + assert not np.any(np.isnan(test_predictions)) + assert np.all(np.isfinite(test_predictions)) + assert test_predictions.min() >= 0. + assert test_predictions.max() <= 40. + + # Too risky to overwrite. + if output_path.exists(): + raise ValueError(f'{output_path} already exists') + + # Write to a local directory, and copy to final path (possibly cns). + # It is not possible to write directly on CNS.
+ evaluator = lsc.PCQM4MEvaluator() + + evaluator.save_test_submission(dict(y_pred=test_predictions), output_path) + return output_path + + +def merge_complementary_results(split: str, results_a: _Predictions, + results_b: _Predictions) -> _Predictions: + """Merges two prediction results with no overlap.""" + + indices_a = set(results_a.indices) + indices_b = set(results_b.indices) + assert not indices_a.intersection(indices_b) + + if split == 'test': + merged_indices = list(sorted(indices_a | indices_b)) + expected_indices = datasets.load_splits()[split] + assert np.all(expected_indices == merged_indices) + + predictions = np.concatenate([results_a.predictions, results_b.predictions]) + indices = np.concatenate([results_a.indices, results_b.indices]) + predictions = _sort_by_indices( + _Predictions(indices=indices, predictions=predictions)) + return predictions + + +def ensemble_valid_predictions( + predictions_list: List[_Predictions]) -> _Predictions: + """Ensembles a list of predictions.""" + index_to_predictions = collections.defaultdict(list) + for predictions in predictions_list: + for idx, pred in zip(predictions.indices, predictions.predictions): + index_to_predictions[idx].append(pred) + + for idx, ensemble_list in index_to_predictions.items(): + if len(ensemble_list) != _NUM_SEEDS: + raise RuntimeError( + 'Graph index in the validation set received wrong number of ' + 'predictions to ensemble.') + + index_to_predictions = { + k: np.median(pred_list, axis=0) + for k, pred_list in index_to_predictions.items() + } + return _sort_by_indices( + _Predictions( + indices=np.array(list(index_to_predictions.keys())), + predictions=np.array(list(index_to_predictions.values())))) + + +def ensemble_test_predictions( + predictions_list: List[_Predictions]) -> _Predictions: + """Ensembles a list of predictions.""" + predictions = np.median([pred.predictions for pred in predictions_list], + axis=0) + common_indices = predictions_list[0].indices + for preds in 
predictions_list[1:]: + assert np.all(preds.indices == common_indices) + + return _Predictions(predictions=predictions, indices=common_indices) + + +def create_submission_from_predictions( + output_path: pathlib.Path, test_predictions: _Predictions) -> pathlib.Path: + """Creates a submission for predictions on a path.""" + assert _SPLIT.value == 'test' + + output_path = _generate_test_prediction_file( + test_predictions.predictions, + output_path=output_path / 'submission_files') + + return output_path / 'y_pred_pcqm4m.npz' + + +def merge_predictions(split: str) -> List[_Predictions]: + """Generates features merged from conformer and non-conformer predictions.""" + merged_predictions: List[_Predictions] = [] + seed = _SEED_START.value + + # Load conformer and non-conformer predictions. + for unused_seed_group in (0, 1): + for k in range(_NUM_KFOLD_SPLITS): + conformer_predictions: _Predictions = load_predictions( + f'{_CONFORMER_PATH.value}/k{k}_seed{seed}', split) + + non_conformer_predictions: _Predictions = load_predictions( + f'{_NON_CONFORMER_PATH.value}/k{k}_seed{seed}', split) + + merged_predictions.append( + merge_complementary_results(_SPLIT.value, conformer_predictions, + non_conformer_predictions)) + + seed += 1 + return merged_predictions + + +def main(_): + split: str = _SPLIT.value + + # Merge conformer and non-conformer predictions. + merged_predictions = merge_predictions(split) + + # Clip before ensembling. + clipped_predictions = list(map(clip_predictions, merged_predictions)) + + # Ensemble predictions. + if split == 'valid': + ensembled_predictions = ensemble_valid_predictions(clipped_predictions) + else: + assert split == 'test' + ensembled_predictions = ensemble_test_predictions(clipped_predictions) + + # Clip after ensembling. 
+ ensembled_predictions = clip_predictions(ensembled_predictions) + + ensembled_predictions_path = pathlib.Path(_OUTPUT_PATH.value) + ensembled_predictions_path.mkdir(parents=True, exist_ok=True) + + with open(ensembled_predictions_path / f'{split}_predictions.dill', + 'wb') as f: + dill.dump(ensembled_predictions, f) + + if split == 'valid': + evaluate_valid_predictions(ensembled_predictions) + else: + assert split == 'test' + output_path = create_submission_from_predictions(ensembled_predictions_path, + ensembled_predictions) + logging.info('Submission files written to %s', output_path) + + +if __name__ == '__main__': + app.run(main) diff --git a/ogb_lsc/pcq/experiment.py b/ogb_lsc/pcq/experiment.py new file mode 100644 index 0000000..522798f --- /dev/null +++ b/ogb_lsc/pcq/experiment.py @@ -0,0 +1,449 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""PCQM4M-LSC Jaxline experiment.""" + +import datetime +import functools +import os +import signal +import threading +from typing import Iterable, Mapping, NamedTuple, Tuple + +from absl import app +from absl import flags +from absl import logging +import chex +import dill +import haiku as hk +import jax +from jax.config import config as jax_config +import jax.numpy as jnp +from jaxline import experiment +from jaxline import platform +from jaxline import utils +import jraph +import numpy as np +import optax +import tensorflow as tf +import tree + +# pylint: disable=g-bad-import-order +import dataset_utils +import datasets +import model + + +FLAGS = flags.FLAGS + + +def _get_step_date_label(global_step: int): + # Date removing microseconds. + date_str = datetime.datetime.now().isoformat().split('.')[0] + return f'step_{global_step}_{date_str}' + + +class _Predictions(NamedTuple): + predictions: np.ndarray + indices: np.ndarray + + +def tf1_ema(ema_value, current_value, decay, step): + """Implements EMA with TF1-style decay warmup.""" + decay = jnp.minimum(decay, (1.0 + step) / (10.0 + step)) + return ema_value * decay + current_value * (1 - decay) + + +def _sort_predictions_by_indices(predictions: _Predictions): + sorted_order = np.argsort(predictions.indices) + return _Predictions( + predictions=predictions.predictions[sorted_order], + indices=predictions.indices[sorted_order]) + + +class Experiment(experiment.AbstractExperiment): + """OGB Graph Property Prediction GraphNet experiment.""" + + CHECKPOINT_ATTRS = { + '_params': 'params', + '_opt_state': 'opt_state', + '_network_state': 'network_state', + '_ema_network_state': 'ema_network_state', + '_ema_params': 'ema_params', + } + + def __init__(self, mode, init_rng, config): + """Initializes experiment.""" + super(Experiment, self).__init__(mode=mode, init_rng=init_rng) + if mode not in ('train', 'eval', 'train_eval_multithreaded'): + raise ValueError(f'Invalid mode {mode}.') + + # Do not use accelerators in 
data pipeline. + tf.config.experimental.set_visible_devices([], device_type='gpu') + tf.config.experimental.set_visible_devices([], device_type='tpu') + + self.mode = mode + self.init_rng = init_rng + self.config = config + + self.loss = None + self.forward = None + + # Needed for checkpoint restore. + self._params = None + self._network_state = None + self._opt_state = None + self._ema_network_state = None + self._ema_params = None + + # _ _ + # | |_ _ __ __ _(_)_ __ + # | __| "__/ _` | | "_ \ + # | |_| | | (_| | | | | | + # \__|_| \__,_|_|_| |_| + # + + def step(self, global_step: jnp.ndarray, rng: jnp.ndarray, **unused_args): + """See Jaxline base class.""" + if self.loss is None: + self._train_init() + + graph = next(self._train_input) + out = self.update_parameters( + self._params, + self._ema_params, + self._network_state, + self._ema_network_state, + self._opt_state, + global_step, + rng, + graph._asdict()) + (self._params, self._ema_params, self._network_state, + self._ema_network_state, self._opt_state, scalars) = out + return utils.get_first(scalars) + + def _construct_loss_config(self): + loss_config = getattr(model, self.config.model.loss_config_name) + if self.config.model.loss_config_name == 'RegressionLossConfig': + return loss_config( + mean=datasets.NORMALIZE_TARGET_MEAN, + std=datasets.NORMALIZE_TARGET_STD, + kwargs=self.config.model.loss_kwargs) + else: + raise ValueError('Unknown Loss Config') + + def _train_init(self): + self.loss = hk.transform_with_state(self._loss) + self._train_input = utils.py_prefetch( + lambda: self._build_numpy_dataset_iterator('train', is_training=True)) + init_stacked_graphs = next(self._train_input) + init_key = utils.bcast_local_devices(self.init_rng) + p_init = jax.pmap(self.loss.init) + self._params, self._network_state = p_init(init_key, + **init_stacked_graphs._asdict()) + + # Learning rate scheduling. 
+ lr_schedule = optax.warmup_cosine_decay_schedule( + **self.config.optimizer.lr_schedule) + + self.optimizer = getattr(optax, self.config.optimizer.name)( + learning_rate=lr_schedule, **self.config.optimizer.optimizer_kwargs) + + self._opt_state = jax.pmap(self.optimizer.init)(self._params) + self.update_parameters = jax.pmap(self._update_parameters, axis_name='i') + if self.config.ema: + self._ema_params = self._params + self._ema_network_state = self._network_state + + def _loss( + self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree: + + graph = jraph.GraphsTuple(**graph) + model_instance = model.GraphPropertyEncodeProcessDecode( + loss_config=self._construct_loss_config(), **self.config.model) + loss, scalars = model_instance.get_loss(graph) + return loss, scalars + + def _maybe_save_predictions( + self, + predictions: jnp.ndarray, + split: str, + global_step: jnp.ndarray, + ): + if not self.config.predictions_dir: + return + output_dir = os.path.join(self.config.predictions_dir, + _get_step_date_label(global_step)) + os.makedirs(output_dir, exist_ok=True) + output_path = os.path.join(output_dir, split + '.dill') + + with open(output_path, 'wb') as f: + dill.dump(predictions, f) + logging.info('Saved %s predictions at: %s', split, output_path) + + def _build_numpy_dataset_iterator(self, split: str, is_training: bool): + dynamic_batch_size_config = ( + self.config.training.dynamic_batch_size + if is_training else self.config.evaluation.dynamic_batch_size) + + return dataset_utils.build_dataset_iterator( + split=split, + dynamic_batch_size_config=dynamic_batch_size_config, + sample_random=self.config.sample_random, + debug=self.config.debug, + is_training=is_training, + **self.config.dataset_config) + + def _update_parameters( + self, + params: hk.Params, + ema_params: hk.Params, + network_state: hk.State, + ema_network_state: hk.State, + opt_state: optax.OptState, + global_step: jnp.ndarray, + rng: jnp.ndarray, + graph: jraph.GraphsTuple, + ) -> 
Tuple[hk.Params, hk.Params, hk.State, hk.State, optax.OptState, + chex.ArrayTree]: + """Updates parameters.""" + def get_loss(*x, **graph): + (loss, scalars), network_state = self.loss.apply(*x, **graph) + return loss, (scalars, network_state) + grad_loss_fn = jax.grad(get_loss, has_aux=True) + scaled_grads, (scalars, network_state) = grad_loss_fn( + params, network_state, rng, **graph) + grads = jax.lax.psum(scaled_grads, axis_name='i') + updates, opt_state = self.optimizer.update(grads, opt_state, params) + params = optax.apply_updates(params, updates) + if ema_params is not None: + ema = lambda x, y: tf1_ema(x, y, self.config.ema_decay, global_step) + ema_params = jax.tree_multimap(ema, ema_params, params) + ema_network_state = jax.tree_multimap(ema, ema_network_state, + network_state) + return params, ema_params, network_state, ema_network_state, opt_state, scalars + + # _ + # _____ ____ _| | + # / _ \ \ / / _` | | + # | __/\ V / (_| | | + # \___| \_/ \__,_|_| + # + + def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray, + **unused_kwargs) -> chex.ArrayTree: + """See Jaxline base class.""" + if self.forward is None: + self._eval_init() + + if self.config.ema: + params = utils.get_first(self._ema_params) + state = utils.get_first(self._ema_network_state) + else: + params = utils.get_first(self._params) + state = utils.get_first(self._network_state) + rng = utils.get_first(rng) + + split = self.config.evaluation.split + predictions, scalars = self._get_predictions( + params, state, rng, + utils.py_prefetch( + functools.partial( + self._build_numpy_dataset_iterator, split, is_training=False))) + self._maybe_save_predictions(predictions, split, global_step[0]) + return scalars + + def _sum_regression_scalars(self, preds: jnp.ndarray, + graph: jraph.GraphsTuple) -> chex.ArrayTree: + """Creates unnormalised values for accumulation.""" + targets = graph.globals['target'] + graph_mask = jraph.get_graph_padding_mask(graph) + # Sum for accumulation, normalise 
later since there are a + # variable number of graphs per batch. + mae = model.sum_with_mask(jnp.abs(targets - preds), graph_mask) + mse = model.sum_with_mask((targets - preds)**2, graph_mask) + count = jnp.sum(graph_mask) + return {'values': {'mae': mae.item(), 'mse': mse.item()}, + 'counts': {'mae': count.item(), 'mse': count.item()}} + + def _get_prediction( + self, + params: hk.Params, + state: hk.State, + rng: jnp.ndarray, + graph: jraph.GraphsTuple, + ) -> np.ndarray: + """Returns predictions for all the graphs in the dataset split.""" + model_out, _ = self.eval_apply(params, state, rng, **graph._asdict()) + prediction = np.squeeze(model_out['globals'], axis=1) + return prediction + + def _get_predictions( + self, + params: hk.Params, + state: hk.State, + rng: jnp.ndarray, + graph_iterator: Iterable[jraph.GraphsTuple], + ) -> Tuple[_Predictions, chex.ArrayTree]: + all_scalars = [] + predictions = [] + graph_indices = [] + for i, graph in enumerate(graph_iterator): + prediction = self._get_prediction(params, state, rng, graph) + if 'target' in graph.globals and not jnp.isnan( + graph.globals['target']).any(): + scalars = self._sum_regression_scalars(prediction, graph) + all_scalars.append(scalars) + num_padding_graphs = jraph.get_number_of_padding_with_graphs_graphs(graph) + num_valid_graphs = len(graph.n_node) - num_padding_graphs + depadded_prediction = prediction[:num_valid_graphs] + predictions.append(depadded_prediction) + graph_indices.append(graph.globals['graph_index'][:num_valid_graphs]) + + if i % 1000 == 0: + logging.info('Generated predictions for %d batches so far', i + 1) + + predictions = _sort_predictions_by_indices( + _Predictions( + predictions=np.concatenate(predictions), + indices=np.concatenate(graph_indices))) + + if all_scalars: + sum_all_args = lambda *l: sum(l) + # Sum over graphs in the dataset. 
+ accum_scalars = tree.map_structure(sum_all_args, *all_scalars) + scalars = tree.map_structure(lambda x, y: x / y, accum_scalars['values'], + accum_scalars['counts']) + else: + scalars = {} + return predictions, scalars + + def _eval_init(self): + self.forward = hk.transform_with_state(self._forward) + self.eval_apply = jax.jit(self.forward.apply) + + def _forward(self, **graph: Mapping[str, chex.ArrayTree]) -> chex.ArrayTree: + + graph = jraph.GraphsTuple(**graph) + model_instance = model.GraphPropertyEncodeProcessDecode( + loss_config=self._construct_loss_config(), **self.config.model) + return model_instance(graph) + + +def _restore_state_to_in_memory_checkpointer(restore_path): + """Initializes experiment state from a checkpoint.""" + + # Load pretrained experiment state. + python_state_path = os.path.join(restore_path, 'checkpoint.dill') + with open(python_state_path, 'rb') as f: + pretrained_state = dill.load(f) + logging.info('Restored checkpoint from %s', python_state_path) + + # Assign state to a dummy experiment instance for the in-memory checkpointer, + # broadcasting to devices. + dummy_experiment = Experiment( + mode='train', init_rng=0, config=FLAGS.config.experiment_kwargs.config) + for attribute, key in Experiment.CHECKPOINT_ATTRS.items(): + setattr(dummy_experiment, attribute, + utils.bcast_local_devices(pretrained_state[key])) + + jaxline_state = dict( + global_step=pretrained_state['global_step'], + experiment_module=dummy_experiment) + snapshot = utils.SnapshotNT(0, jaxline_state) + + # Finally, seed the jaxline `utils.InMemoryCheckpointer` global dict. 
+ utils.GLOBAL_CHECKPOINT_DICT['latest'] = utils.CheckpointNT( + threading.local(), [snapshot]) + + +def _save_state_from_in_memory_checkpointer( + save_path, experiment_class: experiment.AbstractExperiment): + """Saves experiment state to a checkpoint.""" + logging.info('Saving model.') + for checkpoint_name, checkpoint in utils.GLOBAL_CHECKPOINT_DICT.items(): + if not checkpoint.history: + logging.info('Nothing to save in "%s"', checkpoint_name) + continue + + pickle_nest = checkpoint.history[-1].pickle_nest + global_step = pickle_nest['global_step'] + + state_dict = {'global_step': global_step} + for attribute, key in experiment_class.CHECKPOINT_ATTRS.items(): + state_dict[key] = utils.get_first( + getattr(pickle_nest['experiment_module'], attribute)) + save_dir = os.path.join( + save_path, checkpoint_name, _get_step_date_label(global_step)) + python_state_path = os.path.join(save_dir, 'checkpoint.dill') + os.makedirs(save_dir, exist_ok=True) + with open(python_state_path, 'wb') as f: + dill.dump(state_dict, f) + logging.info( + 'Saved "%s" checkpoint to %s', checkpoint_name, python_state_path) + + +def _setup_signals(save_model_fn): + """Sets up a signal for model saving.""" + # Save a model on Ctrl+C. + def sigint_handler(unused_sig, unused_frame): + # Ideally, rather than saving immediately, we would then "wait" for a good + # time to save. In practice this reads from an in-memory checkpoint that + # only saves every 30 seconds or so, so chances of race conditions are very + # small. + save_model_fn() + logging.info(r'Use `Ctrl+\` to save and exit.') + + # Exit on `Ctrl+\`, saving a model. + prev_sigquit_handler = signal.getsignal(signal.SIGQUIT) + def sigquit_handler(unused_sig, unused_frame): + # Restore previous handler early, just in case something goes wrong in the + # next lines, so it is possible to press again and exit. 
+ signal.signal(signal.SIGQUIT, prev_sigquit_handler) + save_model_fn() + logging.info(r'Exiting on `Ctrl+\`') + + # Re-raise for clean exit. + os.kill(os.getpid(), signal.SIGQUIT) + + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGQUIT, sigquit_handler) + + +def main(argv, experiment_class: experiment.AbstractExperiment): + + # Maybe restore a model. + restore_path = FLAGS.config.restore_path + if restore_path: + _restore_state_to_in_memory_checkpointer(restore_path) + + # Maybe save a model. + save_dir = os.path.join(FLAGS.config.checkpoint_dir, 'models') + if FLAGS.config.one_off_evaluate: + save_model_fn = lambda: None # No need to save checkpoint in this case. + else: + save_model_fn = functools.partial( + _save_state_from_in_memory_checkpointer, save_dir, experiment_class) + _setup_signals(save_model_fn) # Save on Ctrl+C (continue) or Ctrl+\ (exit). + + try: + platform.main(experiment_class, argv) + finally: + save_model_fn() # Save at the end of training or in case of exception. + + +if __name__ == '__main__': + jax_config.update('jax_debug_nans', False) + flags.mark_flag_as_required('config') + app.run(lambda argv: main(argv, Experiment)) diff --git a/ogb_lsc/pcq/generate_conformer_features.py b/ogb_lsc/pcq/generate_conformer_features.py new file mode 100644 index 0000000..fa3328e --- /dev/null +++ b/ogb_lsc/pcq/generate_conformer_features.py @@ -0,0 +1,67 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate conformer features to be used for training/predictions.""" + +import multiprocessing as mp +import pickle +from typing import List + +from absl import app +from absl import flags +import numpy as np + +# pylint: disable=g-bad-import-order +import conformer_utils +import datasets + +_SPLITS = flags.DEFINE_spaceseplist( + 'splits', ['test'], 'Splits to compute conformer features for.') + +_OUTPUT_FILE = flags.DEFINE_string( + 'output_file', + None, + required=True, + help='Output file name to write the generated conformer features to.') + +_NUM_PROCS = flags.DEFINE_integer( + 'num_parallel_procs', 64, + 'Number of parallel processes to use for conformer generation.') + + +def generate_conformer_features(smiles: List[str]) -> List[np.ndarray]: + # Conformer generation is a CPU-bound task and hence can get a boost from + # parallel processing. + # To avoid GIL, we choose multiprocessing instead of the + # simpler multi-threading option here for parallel computing. + with mp.Pool(_NUM_PROCS.value) as pool: + return list(pool.map(conformer_utils.compute_conformer, smiles)) + + +def main(_): + smiles = datasets.load_smile_strings(with_labels=False) + indices = set() + for split in _SPLITS.value: + indices.update(datasets.load_splits()[split]) + + smiles = [smiles[i] for i in sorted(indices)] + conformers = generate_conformer_features(smiles) + smiles_to_conformers = dict(zip(smiles, conformers)) + + with open(_OUTPUT_FILE.value, 'wb') as f: + pickle.dump(smiles_to_conformers, f) + + +if __name__ == '__main__': + app.run(main) diff --git a/ogb_lsc/pcq/generate_validation_splits.py b/ogb_lsc/pcq/generate_validation_splits.py new file mode 100644 index 0000000..264aa9b --- /dev/null +++ b/ogb_lsc/pcq/generate_validation_splits.py @@ -0,0 +1,50 @@ +# Copyright 2021 DeepMind Technologies Limited. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generates the k-fold validation splits.""" + +import os +import pickle + +from absl import app +from absl import flags +from absl import logging +import numpy as np + +# pylint: disable=g-bad-import-order +import datasets + + +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', None, required=True, + help='Output directory to write the splits to') + + +K = 10 + + +def main(argv): + del argv + valid_indices = datasets.load_splits()['valid'] + k_splits = np.split(valid_indices, K) + os.makedirs(_OUTPUT_DIR.value, exist_ok=True) + for k_i, split in enumerate(k_splits): + fname = os.path.join(_OUTPUT_DIR.value, f'{k_i}.pkl') + with open(fname, 'wb') as f: + pickle.dump(split, f) + logging.info('Saved: %s', fname) + + +if __name__ == '__main__': + app.run(main) diff --git a/ogb_lsc/pcq/model.py b/ogb_lsc/pcq/model.py new file mode 100644 index 0000000..6bb2b7b --- /dev/null +++ b/ogb_lsc/pcq/model.py @@ -0,0 +1,525 @@ +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""PCQM4M-LSC models.""" + +import copy +import functools +from typing import Any, Dict, Mapping, Sequence, Tuple + +import chex +import haiku as hk +import jax +import jax.numpy as jnp +import jraph +from ml_collections import config_dict + + +_REDUCER_NAMES = { + "sum": jax.ops.segment_sum, + "mean": jraph.segment_mean, +} + +_NUM_EDGE_FEATURES = 13 +_NUM_NODE_FEATURES = 173 + + +@chex.dataclass +class RegressionLossConfig: + """Regression Loss Config.""" + # For normalization and denormalization. + std: float + mean: float + kwargs: Mapping[str, Any] + out_size: int = 1 + + +def _sigmoid_cross_entropy( + logits: jnp.DeviceArray, + labels: jnp.DeviceArray, +) -> jnp.DeviceArray: + log_p = jax.nn.log_sigmoid(logits) + log_not_p = jax.nn.log_sigmoid(-logits) + return -labels * log_p - (1. - labels) * log_not_p + + +def _softmax_cross_entropy( + logits: jnp.DeviceArray, + targets: jnp.DeviceArray, +) -> jnp.DeviceArray: + logits = jax.nn.log_softmax(logits) + return -jnp.sum(targets * logits, axis=-1) + + +def _regression_loss( + pred: jnp.ndarray, + targets: jnp.ndarray, + exponent: int, +) -> jnp.ndarray: + """Regression loss.""" + error = pred - targets + if exponent == 2: + return error ** 2 + elif exponent == 1: + return jnp.abs(error) + else: + raise ValueError(f"Unsupported exponent value {exponent}.") + + +def _build_mlp( + name: str, + output_sizes: Sequence[int], + use_layer_norm=False, + activation=jax.nn.relu, +): + """Builds an MLP, optionally with layernorm.""" + net = hk.nets.MLP( + output_sizes=output_sizes, name=name + "_mlp", activation=activation) + if use_layer_norm: + layer_norm = hk.LayerNorm( + axis=-1, + create_scale=True, + create_offset=True, + name=name + "_layer_norm") + net = hk.Sequential([net, layer_norm]) + return jraph.concatenated_args(net) + + +def _compute_relative_displacement_and_distance( + graph: jraph.GraphsTuple, + 
normalization_factor: float, + use_target: bool, +) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Computes relative displacements and distances.""" + + if use_target: + node_positions = graph.nodes["positions_targets"] + else: + node_positions = graph.nodes["positions"] + relative_displacement = node_positions[ + graph.receivers] - node_positions[graph.senders] + + # Note due to the random rotations in space, mean across all nodes across + # all batches is guaranteed to be zero, and the standard deviation is + # guaranteed to be the same for all 3 coordinates, so we only need to scale + # by a single value. + relative_displacement /= normalization_factor + relative_distance = jnp.linalg.norm( + relative_displacement, axis=-1, keepdims=True) + return relative_displacement, relative_distance + + +def _broadcast_global_to_nodes( + global_feature: jnp.ndarray, + graph: jraph.GraphsTuple, +) -> jnp.ndarray: + graph_idx = jnp.arange(graph.n_node.shape[0]) + sum_n_node = jax.tree_leaves(graph.nodes)[0].shape[0] + node_graph_idx = jnp.repeat( + graph_idx, graph.n_node, axis=0, total_repeat_length=sum_n_node) + return global_feature[node_graph_idx] + + +def _broadcast_global_to_edges( + global_feature: jnp.ndarray, + graph: jraph.GraphsTuple, +) -> jnp.ndarray: + graph_idx = jnp.arange(graph.n_edge.shape[0]) + sum_n_edge = graph.senders.shape[0] + edge_graph_idx = jnp.repeat( + graph_idx, graph.n_edge, axis=0, total_repeat_length=sum_n_edge) + return global_feature[edge_graph_idx] + + +class GraphPropertyEncodeProcessDecode(hk.Module): + """Encode-process-decode model for graph property prediction.""" + + def __init__( + self, + loss_config: config_dict.ConfigDict, + mlp_hidden_size: int, + mlp_layers: int, + latent_size: int, + use_layer_norm: bool, + num_message_passing_steps: int, + shared_message_passing_weights: bool, + mask_padding_graph_at_every_step: bool, + loss_config_name: str, + loss_kwargs: config_dict.ConfigDict, + processor_mode: str, + global_reducer: str, + 
node_reducer: str, + dropedge_rate: float, + dropnode_rate: float, + aux_multiplier: float, + ignore_globals: bool, + ignore_globals_from_final_layer_for_predictions: bool, + add_relative_distance: bool = False, + add_relative_displacement: bool = False, + add_absolute_positions: bool = False, + position_normalization: float = 1., + relative_displacement_normalization: float = 1., + add_misc_node_features: bool = None, + name="GraphPropertyEncodeProcessDecode", + ): + super(GraphPropertyEncodeProcessDecode, self).__init__() + self._loss_config = loss_config + + self._config = config_dict.ConfigDict(dict( + loss_config=loss_config, + mlp_hidden_size=mlp_hidden_size, + mlp_layers=mlp_layers, + latent_size=latent_size, + use_layer_norm=use_layer_norm, + num_message_passing_steps=num_message_passing_steps, + shared_message_passing_weights=shared_message_passing_weights, + mask_padding_graph_at_every_step=mask_padding_graph_at_every_step, + loss_config_name=loss_config_name, + loss_kwargs=loss_kwargs, + processor_mode=processor_mode, + global_reducer=global_reducer, + node_reducer=node_reducer, + dropedge_rate=dropedge_rate, + dropnode_rate=dropnode_rate, + aux_multiplier=aux_multiplier, + ignore_globals=ignore_globals, + ignore_globals_from_final_layer_for_predictions=ignore_globals_from_final_layer_for_predictions, + add_relative_distance=add_relative_distance, + add_relative_displacement=add_relative_displacement, + add_absolute_positions=add_absolute_positions, + position_normalization=position_normalization, + relative_displacement_normalization=relative_displacement_normalization, + add_misc_node_features=add_misc_node_features, + )) + + def __call__(self, graph: jraph.GraphsTuple) -> chex.ArrayTree: + """Model inference step.""" + out = self._forward(graph, is_training=False) + if isinstance(self._loss_config, RegressionLossConfig): + out["globals"] = out[ + "globals"]*self._loss_config.std + self._loss_config.mean + return out + + 
@hk.experimental.name_like("__call__") + def get_loss( + self, + graph: jraph.GraphsTuple, + is_training: bool = True, + ) -> Tuple[jnp.ndarray, chex.ArrayTree]: + """Model loss.""" + scalars = get_utilization_scalars(graph) + targets = copy.deepcopy(graph.globals["target"]) + if len(targets.shape) == 1: + targets = targets[:, None] + del graph.globals["target"] + target_mask = None + if "target_mask" in graph.globals: + target_mask = copy.deepcopy(graph.globals["target_mask"]) + del graph.globals["target_mask"] + + out = self._forward(graph, is_training) + + if isinstance(self._loss_config, RegressionLossConfig): + normalized_targets = ( + (targets - self._loss_config.mean) / self._loss_config.std) + per_graph_and_head_loss = _regression_loss( + out["globals"], normalized_targets, **self._loss_config.kwargs) + else: + raise TypeError(type(self._loss_config)) + + # Mask out nans + if target_mask is None: + per_graph_and_head_loss = jnp.mean(per_graph_and_head_loss, axis=1) + else: + per_graph_and_head_loss = jnp.sum( + per_graph_and_head_loss * target_mask, axis=1) + per_graph_and_head_loss /= jnp.sum(target_mask + 1e-8, axis=1) + + g_mask = jraph.get_graph_padding_mask(graph) + loss = _mean_with_mask(per_graph_and_head_loss, g_mask) + scalars.update({"loss": loss}) + + if self._config.aux_multiplier > 0: + atom_loss = self._get_node_auxiliary_loss( + graph, out["atom_one_hots"], graph.nodes["atom_one_hots_targets"], + is_regression=False) + bond_loss = self._get_edge_auxiliary_loss( + graph, out["bond_one_hots"], graph.edges["bond_one_hots_targets"], + is_regression=False) + loss += (atom_loss + bond_loss)*self._config.aux_multiplier + scalars.update({"atom_loss": atom_loss, "bond_loss": bond_loss}) + + scaled_loss = loss / jax.device_count() + scalars.update({"total_loss": loss}) + return scaled_loss, scalars + + @hk.transparent + def _prepare_features(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: + """Prepares features keys into flat node, edge and 
global features.""" + + # Collect edge features. + edge_features_list = [graph.edges["bond_one_hots"]] + if (self._config.add_relative_displacement or + self._config.add_relative_distance): + (relative_displacement, relative_distance + ) = _compute_relative_displacement_and_distance( + graph, self._config.relative_displacement_normalization, + use_target=False) + + if self._config.add_relative_displacement: + edge_features_list.append(relative_displacement) + if self._config.add_relative_distance: + edge_features_list.append(relative_distance) + mask_at_edges = _broadcast_global_to_edges( + graph.globals["positions_nan_mask"], graph) + edge_features_list.append(mask_at_edges[:, None].astype(jnp.float32)) + + edge_features = jnp.concatenate(edge_features_list, axis=-1) + + # Collect node features + node_features_list = [graph.nodes["atom_one_hots"]] + + if self._config.add_absolute_positions: + node_features_list.append( + graph.nodes["positions"] / self._config.position_normalization) + mask_at_nodes = _broadcast_global_to_nodes( + graph.globals["positions_nan_mask"], graph) + node_features_list.append(mask_at_nodes[:, None].astype(jnp.float32)) + + node_features = jnp.concatenate(node_features_list, axis=-1) + + global_features = jnp.zeros((len(graph.n_node), self._config.latent_size)) + chex.assert_tree_shape_prefix(global_features, (len(graph.n_node),)) + return graph._replace( + nodes=node_features, edges=edge_features, globals=global_features) + + @hk.transparent + def _encoder( + self, + graph: jraph.GraphsTuple, + is_training: bool, + ) -> jraph.GraphsTuple: + """Builds the encoder.""" + del is_training # unused + graph = self._prepare_features(graph) + + # Run encoders in all of the node, edge and global features. 
+ output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers + output_sizes += [self._config.latent_size] + build_mlp = functools.partial( + _build_mlp, + output_sizes=output_sizes, + use_layer_norm=self._config.use_layer_norm, + ) + + gmf = jraph.GraphMapFeatures( + embed_edge_fn=build_mlp("edge_encoder"), + embed_node_fn=build_mlp("node_encoder"), + embed_global_fn=None + if self._config.ignore_globals else build_mlp("global_encoder"), + ) + return gmf(graph) + + @hk.transparent + def _processor( + self, + graph: jraph.GraphsTuple, + is_training: bool, + ) -> jraph.GraphsTuple: + """Builds the processor.""" + output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers + output_sizes += [self._config.latent_size] + build_mlp = functools.partial( + _build_mlp, + output_sizes=output_sizes, + use_layer_norm=self._config.use_layer_norm, + ) + + shared_weights = self._config.shared_message_passing_weights + node_reducer = _REDUCER_NAMES[self._config.node_reducer] + global_reducer = _REDUCER_NAMES[self._config.global_reducer] + + def dropout_if_training(fn, dropout_rate: float): + def wrapped(*args): + out = fn(*args) + if is_training: + mask = hk.dropout(hk.next_rng_key(), dropout_rate, + jnp.ones([out.shape[0], 1])) + out = out * mask + return out + return wrapped + + num_mps = self._config.num_message_passing_steps + for step in range(num_mps): + if step == 0 or not shared_weights: + suffix = "shared" if shared_weights else step + + update_edge_fn = dropout_if_training( + build_mlp(f"edge_processor_{suffix}"), + dropout_rate=self._config.dropedge_rate) + + update_node_fn = dropout_if_training( + build_mlp(f"node_processor_{suffix}"), + dropout_rate=self._config.dropnode_rate) + + if self._config.ignore_globals: + gnn = jraph.InteractionNetwork( + update_edge_fn=update_edge_fn, + update_node_fn=update_node_fn, + aggregate_edges_for_nodes_fn=node_reducer) + else: + gnn = jraph.GraphNetwork( + update_edge_fn=update_edge_fn, + 
update_node_fn=update_node_fn, + update_global_fn=build_mlp(f"global_processor_{suffix}"), + aggregate_edges_for_nodes_fn=node_reducer, + aggregate_nodes_for_globals_fn=global_reducer, + aggregate_edges_for_globals_fn=global_reducer, + ) + + mode = self._config.processor_mode + + if mode == "mlp": + graph = gnn(graph) + + elif mode == "resnet": + new_graph = gnn(graph) + graph = graph._replace( + nodes=graph.nodes + new_graph.nodes, + edges=graph.edges + new_graph.edges, + globals=graph.globals + new_graph.globals, + ) + else: + raise ValueError(f"Unknown processor_mode `{mode}`") + + if self._config.mask_padding_graph_at_every_step: + graph = _mask_out_padding_graph(graph) + + return graph + + @hk.transparent + def _decoder( + self, + graph: jraph.GraphsTuple, + input_graph: jraph.GraphsTuple, + is_training: bool, + ) -> chex.ArrayTree: + """Builds the decoder.""" + del is_training # unused. + + output_sizes = [self._config.mlp_hidden_size] * self._config.mlp_layers + output_sizes += [self._loss_config.out_size] + net = _build_mlp("regress_out", output_sizes, use_layer_norm=False) + summed_nodes = _aggregate_nodes_to_globals(graph, graph.nodes) + inputs_to_global_decoder = [summed_nodes] + if not self._config.ignore_globals_from_final_layer_for_predictions: + inputs_to_global_decoder.append(graph.globals) + + out = net(jnp.concatenate(inputs_to_global_decoder, axis=-1)) + out_dict = {} + out_dict["globals"] = out + + # Note "linear" names are for compatibility with pre-trained model names. 
+ out_dict["bond_one_hots"] = hk.Linear( + _NUM_EDGE_FEATURES, name="linear")(graph.edges) + out_dict["atom_one_hots"] = hk.Linear( + _NUM_NODE_FEATURES, name="linear_1")(graph.nodes) + return out_dict + + @hk.transparent + def _forward(self, graph: jraph.GraphsTuple, is_training: bool): + input_graph = jraph.GraphsTuple(*graph) + with hk.experimental.name_scope("encoder_scope"): + graph = self._encoder(graph, is_training) + with hk.experimental.name_scope("processor_scope"): + graph = self._processor(graph, is_training) + with hk.experimental.name_scope("decoder_scope"): + out = self._decoder(graph, input_graph, is_training) + return out + + def _get_node_auxiliary_loss( + self, graph, pred, targets, is_regression, additional_mask=None): + loss = self._get_loss(pred, targets, is_regression) + target_mask = jraph.get_node_padding_mask(graph) + + if additional_mask is not None: + loss *= additional_mask + target_mask = jnp.logical_and(target_mask, additional_mask) + + return _mean_with_mask(loss, target_mask) + + def _get_edge_auxiliary_loss( + self, graph, pred, targets, is_regression, additional_mask=None): + loss = self._get_loss(pred, targets, is_regression) + target_mask = jraph.get_edge_padding_mask(graph) + + if additional_mask is not None: + loss *= additional_mask + target_mask = jnp.logical_and(target_mask, additional_mask) + + return _mean_with_mask(loss, target_mask) + + def _get_loss(self, pred, targets, is_regression): + if is_regression: + loss = ((pred - targets)**2).mean(axis=-1) + else: + targets /= jnp.maximum(1., jnp.sum(targets, axis=-1, keepdims=True)) + loss = _softmax_cross_entropy(pred, targets) + return loss + + +def get_utilization_scalars( + padded_graph: jraph.GraphsTuple) -> Dict[str, float]: + padding_nodes = jraph.get_number_of_padding_with_graphs_nodes(padded_graph) + all_nodes = len(jax.tree_leaves(padded_graph.nodes)[0]) + padding_edges = jraph.get_number_of_padding_with_graphs_edges(padded_graph) + all_edges = 
len(jax.tree_leaves(padded_graph.edges)[0]) + padding_graphs = jraph.get_number_of_padding_with_graphs_graphs(padded_graph) + all_graphs = len(padded_graph.n_node) + return {"node_utilization": 1 - (padding_nodes / all_nodes), + "edge_utilization": 1 - (padding_edges / all_edges), + "graph_utilization": 1 - (padding_graphs / all_graphs)} + + +def sum_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: + return (mask * array).sum(0) + + +def _mean_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: + num_valid_rows = mask.sum(0) + return sum_with_mask(array, mask) / num_valid_rows + + +def _mask_out_padding_graph( + padded_graph: jraph.GraphsTuple) -> jraph.GraphsTuple: + return padded_graph._replace( + nodes=jnp.where( + jraph.get_node_padding_mask( + padded_graph)[:, None], padded_graph.nodes, 0.), + edges=jnp.where( + jraph.get_edge_padding_mask( + padded_graph)[:, None], padded_graph.edges, 0.), + globals=jnp.where( + jraph.get_graph_padding_mask( + padded_graph)[:, None], padded_graph.globals, 0.), + ) + + +def _aggregate_nodes_to_globals(graph, node_features): + n_graph = graph.n_node.shape[0] + sum_n_node = jax.tree_leaves(graph.nodes)[0].shape[0] + graph_idx = jnp.arange(n_graph) + node_gr_idx = jnp.repeat( + graph_idx, graph.n_node, axis=0, total_repeat_length=sum_n_node) + return jax.ops.segment_sum(node_features, node_gr_idx, num_segments=n_graph) diff --git a/ogb_lsc/pcq/requirements.txt b/ogb_lsc/pcq/requirements.txt new file mode 100644 index 0000000..0168a7b --- /dev/null +++ b/ogb_lsc/pcq/requirements.txt @@ -0,0 +1,21 @@ +wheel +absl-py>=0.12.0 +chex>=0.0.7 +dill>=0.3.3 +dm-haiku>=0.0.4 +# jaxline +git+https://github.com/deepmind/jaxline@927695a0b88e480d085590736e2daa36c2f2c68b +optax>=0.0.8 +ml_collections +numpy>=1.16.4 +tensorflow>=2.5.0 +tensorflow-datasets>=4.3.0 +dm-tree>=0.1.6 +ogb>=1.3.1 +# graph_nets +git+git://github.com/deepmind/graph_nets.git@64771dff0d74ca8e77b1f1dcd5a7d26634356d61 +rdkit-pypi>=2021.3.2.3 
+jaxlib>=0.1.57+cuda110 +# jraph +git+git://github.com/deepmind/jraph.git@80d2f9e4d82f841a71d56ba44fa9e8781e93ae1f +google-cloud-storage diff --git a/ogb_lsc/pcq/run_preprocessing.sh b/ogb_lsc/pcq/run_preprocessing.sh new file mode 100755 index 0000000..307d9b1 --- /dev/null +++ b/ogb_lsc/pcq/run_preprocessing.sh @@ -0,0 +1,48 @@ +#!/bin/bash +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +set -x + +while getopts ":r:" opt; do + case ${opt} in + r ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: run_preprocessing.sh -r " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + + +DATA_ROOT=${TASK_ROOT}/data/ + + +python "${SCRIPT_DIR}"/generate_validation_splits.py \ + --output_dir="${DATA_ROOT}/k_fold_splits" + +mkdir -p ${DATA_ROOT}/preprocessed/ +python "${SCRIPT_DIR}"/generate_conformer_features.py \ + --splits="test valid train" \ + --num_parallel_procs=32 \ + --output_file="${DATA_ROOT}/preprocessed/smile_to_conformer.pkl" diff --git a/ogb_lsc/pcq/run_pretrained_eval.sh b/ogb_lsc/pcq/run_pretrained_eval.sh new file mode 100755 index 0000000..0b91143 --- /dev/null +++ b/ogb_lsc/pcq/run_pretrained_eval.sh @@ -0,0 +1,107 @@ +# Copyright 2021 DeepMind Technologies Limited. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Generates test predictions from trained model checkpoints. + +set -e +set -x + +while getopts ":r:" opt; do + case ${opt} in + r ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: run_pretrained_eval.sh -r " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +SPLIT="test" + +DATA_ROOT=${TASK_ROOT}/data/ # For valid k-fold splits. +MODELS_ROOT=${TASK_ROOT}/models/ +CACHED_CONFORMERS_DIR=${TASK_ROOT}/online_conformers/$SPLIT +CACHED_CONFORMERS_FILE=${CACHED_CONFORMERS_DIR}/smiles_to_conformers.pkl +OUTPUT_DIR=${TASK_ROOT}/predictions/ + +# First generate and cache the conformer feature data for the test split. +mkdir -p ${CACHED_CONFORMERS_DIR} +time python "${SCRIPT_DIR}"/generate_conformer_features.py \ + --splits=$SPLIT \ + --num_parallel_procs=32 \ + --output_file=${CACHED_CONFORMERS_FILE} + +seed=42 +# Share GPU between two runs by disabling XLA memory pre-allocation. 
+export XLA_PYTHON_CLIENT_PREALLOCATE=false + +for seed_group in 0 1; do + for k in `seq 0 9`; do + # Conformer + predictions_dir="${OUTPUT_DIR}/predictions/${SPLIT}/conformer/k${k}_seed${seed}" + time python "${SCRIPT_DIR}"/experiment.py \ + --jaxline_mode="eval" \ + --config="${SCRIPT_DIR}"/config.py \ + --config.experiment_kwargs.config.evaluation.split=$SPLIT \ + --config.restore_path=${MODELS_ROOT}/conformer/k${k}_seed${seed} \ + --config.experiment_kwargs.config.predictions_dir=${predictions_dir} \ + --config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_config.k_fold_split_id=$k \ + --config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \ + --config.experiment_kwargs.config.dataset_config.filter_in_or_out_samples_with_nans_in_conformers="out" \ + --config.experiment_kwargs.config.model.latent_size=256 \ + --config.experiment_kwargs.config.model.mlp_hidden_size=1024 \ + --config.experiment_kwargs.config.model.num_message_passing_steps=32 \ + --config.one_off_evaluate & + + + # Non-Conformer + predictions_dir="${OUTPUT_DIR}/predictions/${SPLIT}/non_conformer/k${k}_seed${seed}" + time python "${SCRIPT_DIR}"/experiment.py \ + --jaxline_mode="eval" \ + --config="${SCRIPT_DIR}"/config.py \ + --config.experiment_kwargs.config.evaluation.split=$SPLIT \ + --config.restore_path=${MODELS_ROOT}/non_conformer/k${k}_seed${seed} \ + --config.experiment_kwargs.config.predictions_dir=${predictions_dir} \ + --config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_config.k_fold_split_id=$k \ + --config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \ + --config.experiment_kwargs.config.dataset_config.filter_in_or_out_samples_with_nans_in_conformers="in" \ + --config.experiment_kwargs.config.model.num_message_passing_steps=50 \ + 
--config.experiment_kwargs.config.model.add_relative_distance=false \ + --config.experiment_kwargs.config.model.add_relative_displacement=false \ + --config.one_off_evaluate & + + wait + + ((seed=seed+1)) + + done +done + +# Ensemble predictions. +time python "${SCRIPT_DIR}"/ensemble_predictions.py \ + --seed_start=42 \ + --split=$SPLIT \ + --conformer_path="${OUTPUT_DIR}/predictions/${SPLIT}/conformer" \ + --non_conformer_path="${OUTPUT_DIR}/predictions/${SPLIT}/non_conformer" \ + --output_path="${OUTPUT_DIR}/ensembled_predictions/" diff --git a/ogb_lsc/pcq/run_training.sh b/ogb_lsc/pcq/run_training.sh new file mode 100755 index 0000000..8cd14e3 --- /dev/null +++ b/ogb_lsc/pcq/run_training.sh @@ -0,0 +1,134 @@ +#!/bin/bash +# Copyright 2021 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e +set -x + +while getopts ":r:" opt; do + case ${opt} in + r ) + TASK_ROOT=$OPTARG + ;; + \? ) + echo "Usage: run_training.sh -r " + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + ;; + esac +done +shift $((OPTIND -1)) + +# Get this script's directory. +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" + +echo " +This script is provided for illustrative purposes. It is not practical for +actual training since it only uses a single machine, and likely requires +reducing the batch size and/or model size to fit on a single GPU. 
+ +For the actual submission we distributed training in several ways: +* We used 4x Cloud TPU v4s to implement batch parallelism, so the batch + size is effectively 8 times larger than the values in the config. +* We ran the online early-stopping evaluator on a separate machine with an + NVIDIA V100 GPU. +* We ran identical replicas of all of the above for each of the 40 models + trained. + +Using those mechanisms, training runs at ~4 and 6 steps per second, reaching +500k steps in about 1.5 days. +" + +echo " +Pre-requisites (see README): +* cd into the directory containing the 'ogb_lsc' folder. +* Python dependencies have been installed. +* Pre-processed data is available at the hardcoded paths. +" + +read -p "Press enter to continue" + +DATA_ROOT=${TASK_ROOT}/data/ +CHECKPOINT_DIR=${TASK_ROOT}/checkpoints/ +CACHED_CONFORMERS_FILE="${DATA_ROOT}/preprocessed/smile_to_conformer.pkl" + +# We run two seeds for each model of the k=10 k-fold. +BASE_SEED=42 +for SEED_OFFSET in 0 10; do +for K_FOLD_INDEX in {0..9}; do + MODEL_SEED=`expr ${BASE_SEED} + ${SEED_OFFSET} + ${K_FOLD_INDEX}` + + echo "Running k=${K_FOLD_INDEX} with init seed ${MODEL_SEED}" + + # This runs training (each model is trained on the train split + 90% of the + # validation split) with early stopping, storing both the "latest" model and + # the "best" early-stopped model at `--config.checkpoint_dir`. + + # Models are early stopped based on accuracy on the 10% of the validation + # data left out from training (each K_FOLD_INDEX leaves out a different 10% + # of the data). + + # Models are stored at the end of training. Intermediate models can also be + # stored while training by sending a SIGINT signal (Ctrl+C), which will not + # interrupt the training. + + # It is possible to interrupt training using (Ctrl+\) and then continue by + # passing the corresponding `--config.restore_path=${RESTORE_PATH}`, which is + # stored when interrupting. + + # Conformer.
+ SUFFIX="conformer/k${K_FOLD_INDEX}_seed${MODEL_SEED}" + echo "Running ${SUFFIX}" + python "${SCRIPT_DIR}"/experiment.py \ + --jaxline_mode="train_eval_multithreaded" \ + --config="${SCRIPT_DIR}"/config.py \ + --config.random_seed=${MODEL_SEED} \ + --config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \ + --config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_config.k_fold_split_id=${K_FOLD_INDEX} \ + --config.experiment_kwargs.config.dataset_config.num_k_fold_splits=10 \ + --config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \ + --config.experiment_kwargs.config.model.latent_size=256 \ + --config.experiment_kwargs.config.model.mlp_hidden_size=1024 \ + --config.experiment_kwargs.config.model.num_message_passing_steps=32 \ + --config.experiment_kwargs.config.evaluation.split="valid" + + + # Non-Conformer. + SUFFIX="non_conformer/k${K_FOLD_INDEX}_seed${MODEL_SEED}" + echo "Running ${SUFFIX}" + python "${SCRIPT_DIR}"/experiment.py \ + --jaxline_mode="train_eval_multithreaded" \ + --config="${SCRIPT_DIR}"/config.py \ + --config.random_seed=${MODEL_SEED} \ + --config.checkpoint_dir=${CHECKPOINT_DIR}/${SUFFIX} \ + --config.experiment_kwargs.config.dataset_config.data_root=${DATA_ROOT} \ + --config.experiment_kwargs.config.dataset_config.k_fold_split_id=${K_FOLD_INDEX} \ + --config.experiment_kwargs.config.dataset_config.num_k_fold_splits=10 \ + --config.experiment_kwargs.config.dataset_config.cached_conformers_file=${CACHED_CONFORMERS_FILE} \ + --config.experiment_kwargs.config.model.num_message_passing_steps=50 \ + --config.experiment_kwargs.config.model.add_relative_distance=false \ + --config.experiment_kwargs.config.model.add_relative_displacement=false \ + --config.experiment_kwargs.config.evaluation.split="valid" + +done +done + +# Each of the 40 jobs (two seeds for each value of k, for both conformer and +# non-conformer) stores its early-stopped model at a path of the form:
+# RESTORE_PATH=${--config.checkpoint_dir}/models/best/step_${STEP}_${TIMESTAMP} +# These can then be used as the "RESTORE_PATHS" for `./run_pretrained_eval.sh`. + +echo "Done"
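The `RESTORE_PATH` layout described in the closing comment of `run_training.sh` can be resolved programmatically once training has finished. The following is a minimal sketch, not part of the original scripts: `latest_best_checkpoint` is a hypothetical helper, and it assumes that the entry under `models/best` with the highest step is the early-stopped model to restore and that `${STEP}` is a plain integer.

```shell
#!/bin/bash
# Hypothetical helper (an assumption, not in the original scripts): print the
# "best" checkpoint path with the highest step for one training job.
# Expected layout: <checkpoint_dir>/models/best/step_<STEP>_<TIMESTAMP>
latest_best_checkpoint() {
  local best_dir="$1/models/best"
  local best_path="" best_step=-1 path name step
  for path in "${best_dir}"/step_*; do
    # Skip the unexpanded pattern when nothing matches the glob.
    [ -e "${path}" ] || continue
    name="${path##*/}"    # step_<STEP>_<TIMESTAMP>
    step="${name#step_}"  # <STEP>_<TIMESTAMP>
    step="${step%%_*}"    # <STEP>
    if [ "${step}" -gt "${best_step}" ]; then
      best_step="${step}"
      best_path="${path}"
    fi
  done
  # Return a non-zero status when no checkpoint was found.
  if [ -z "${best_path}" ]; then
    return 1
  fi
  echo "${best_path}"
}
```

Calling `latest_best_checkpoint "${CHECKPOINT_DIR}/conformer/k0_seed42"` would then print a candidate path that could be passed as `--config.restore_path`, provided the assumptions above hold for the checkpoints on disk.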