Update README.md.

PiperOrigin-RevId: 390606818
This commit is contained in:
Luyu Wang
2021-08-13 15:19:23 +01:00
committed by Diego de Las Casas
parent 1243baa0b0
commit 17b5dbaf88
2 changed files with 58 additions and 20 deletions
+57 -19
View File
@@ -11,12 +11,12 @@ conditioned on graph and generate graphs given text.
[Jax](https://github.com/google/jax#installation), [Jax](https://github.com/google/jax#installation),
[Haiku](https://github.com/deepmind/dm-haiku#installation), [Haiku](https://github.com/deepmind/dm-haiku#installation),
[Optax](https://github.com/deepmind/dm-haiku#installation), and [Optax](https://optax.readthedocs.io/en/latest/#installation), and
[Jraph](https://github.com/deepmind/jraph) are needed for this package. It has [Jraph](https://github.com/deepmind/jraph) are needed for this package. It has
been developed and tested on python 3 with the following packages: been developed and tested on python 3 with the following packages:
* Jax==0.2.13 * Jax==0.2.13
* Haiku==0.0.5 * Haiku==0.0.5.dev
* Optax==0.0.6 * Optax==0.0.6
* Jraph==0.0.1.dev * Jraph==0.0.1.dev
@@ -167,7 +167,56 @@ it elsewhere.
## Run baseline models ## Run baseline models
Note: our code supports training with multiple GPUs. To quickly test-run a small model with 1 GPU:
```base
python main.py --model_type=graph2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/graph2text \
--job_mode=train \
--train_batch_size=2 \
--gnn_num_layers=1 \
--num_gpus=1
```
To run the default baseline unconditional TransformerXL on Wikigraphs with 8
GPUs:
```base
python main.py --model_type=text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/text \
--job_mode=train \
--train_batch_size=64 \
--gnn_num_layers=1 \
--num_gpus=8
```
To run the default baseline BoW-based TransformerXL on Wikigraphs with 8
GPUs:
```base
python main.py --model_type=bow2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/bow2text \
--job_mode=train \
--train_batch_size=64 \
--gnn_num_layers=1 \
--num_gpus=8
```
To run the default baseline Nodes-only GNN-based TransformerXL on Wikigraphs
with 8 GPUs:
```base
python main.py --model_type=bow2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/bow2text \
--job_mode=train \
--train_batch_size=64 \
--gnn_num_layers=0 \
--num_gpus=8
```
To run the default baseline GNN-based TransformerXL on Wikigraphs with 8 To run the default baseline GNN-based TransformerXL on Wikigraphs with 8
GPUs: GPUs:
@@ -182,22 +231,11 @@ python main.py --model_type=graph2text \
--num_gpus=8 --num_gpus=8
``` ```
We ran our experiments in the paper using 8 Nvidia V100 GPUs. To allow for We ran our experiments in the paper using 8 Nvidia V100 GPUs. Reduce the batch
batch parallization for the GNN-based (graph2text) model, we pad graphs to size if the model does not fit into memory. To allow for batch parallization for
the largest graph in the batch. The full run takes almost 4 days. BoW- and the GNN-based (graph2text) model, we pad graphs to the largest graph in the
nodes-based models can be trained within 14 hours because there is no batch. The full run takes almost 4 days. BoW- and nodes-based models can be
additional padding. trained within 14 hours because there is no additional padding.
Or to quickly test-run a small model:
```base
python main.py --model_type=graph2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/graph2text \
--job_mode=train \
--train_batch_size=2 \
--gnn_num_layers=1
```
To evaluate the model on the validation set (this only uses 1 GPU): To evaluate the model on the validation set (this only uses 1 GPU):
+1 -1
View File
@@ -33,7 +33,7 @@ from setuptools import setup
setup( setup(
name='wikigraphs', name='wikigraphs',
version='0.0.2', version='0.1.0',
description='A Wikipedia - knowledge graph paired dataset.', description='A Wikipedia - knowledge graph paired dataset.',
url='https://github.com/deepmind/deepmind-research/tree/master/wikigraphs', url='https://github.com/deepmind/deepmind-research/tree/master/wikigraphs',
author='DeepMind', author='DeepMind',