Update README.md.

PiperOrigin-RevId: 390606818
This commit is contained in:
Luyu Wang
2021-08-13 15:19:23 +01:00
committed by Diego de Las Casas
parent 1243baa0b0
commit 17b5dbaf88
2 changed files with 58 additions and 20 deletions
+57 -19
@@ -11,12 +11,12 @@ conditioned on graph and generate graphs given text.
[Jax](https://github.com/google/jax#installation),
[Haiku](https://github.com/deepmind/dm-haiku#installation),
[Optax](https://github.com/deepmind/dm-haiku#installation), and
[Optax](https://optax.readthedocs.io/en/latest/#installation), and
[Jraph](https://github.com/deepmind/jraph) are needed for this package. It has
been developed and tested on Python 3 with the following packages:
* Jax==0.2.13
* Haiku==0.0.5
* Haiku==0.0.5.dev
* Optax==0.0.6
* Jraph==0.0.1.dev
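As a quick sanity check before training, the installed versions can be compared against the ones listed above. The following is a minimal sketch, not part of the package: the helper name `check_versions` is hypothetical, and the PyPI distribution names (`jax`, `dm-haiku`, `optax`, `jraph`) are assumptions about how the packages were installed.

```python
from importlib import metadata

# Assumed PyPI distribution names mapped to the versions listed in this README.
EXPECTED = {
    "jax": "0.2.13",
    "dm-haiku": "0.0.5.dev",
    "optax": "0.0.6",
    "jraph": "0.0.1.dev",
}


def check_versions(expected):
    """Return {name: (expected, installed_or_None, matches)} for each package."""
    report = {}
    for name, wanted in expected.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # the distribution is not installed at all
        # Loose prefix match so that e.g. "0.0.5.dev0" is accepted for "0.0.5.dev".
        matches = installed is not None and installed.startswith(wanted)
        report[name] = (wanted, installed, matches)
    return report


if __name__ == "__main__":
    for name, (wanted, installed, ok) in check_versions(EXPECTED).items():
        status = "ok" if ok else "MISMATCH"
        print(f"{name}: expected {wanted}, found {installed} [{status}]")
```

A mismatch does not necessarily mean the code will fail; these are only the versions the package was developed and tested with.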
@@ -167,7 +167,56 @@ it elsewhere.
## Run baseline models
Note: our code supports training with multiple GPUs.
To quickly test-run a small model with 1 GPU:
```bash
python main.py --model_type=graph2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/graph2text \
--job_mode=train \
--train_batch_size=2 \
--gnn_num_layers=1 \
--num_gpus=1
```
To run the default baseline unconditional TransformerXL on Wikigraphs with 8
GPUs:
```bash
python main.py --model_type=text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/text \
--job_mode=train \
--train_batch_size=64 \
--gnn_num_layers=1 \
--num_gpus=8
```
To run the default baseline BoW-based TransformerXL on Wikigraphs with 8
GPUs:
```bash
python main.py --model_type=bow2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/bow2text \
--job_mode=train \
--train_batch_size=64 \
--gnn_num_layers=1 \
--num_gpus=8
```
To run the default baseline Nodes-only GNN-based TransformerXL on Wikigraphs
with 8 GPUs:
```bash
python main.py --model_type=bow2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/bow2text \
--job_mode=train \
--train_batch_size=64 \
--gnn_num_layers=0 \
--num_gpus=8
```
To run the default baseline GNN-based TransformerXL on Wikigraphs with 8
GPUs:
@@ -182,22 +231,11 @@ python main.py --model_type=graph2text \
--num_gpus=8
```
We ran our experiments in the paper using 8 Nvidia V100 GPUs. To allow for
batch parallelization for the GNN-based (graph2text) model, we pad graphs to
the largest graph in the batch. The full run takes almost 4 days. BoW- and
nodes-based models can be trained within 14 hours because there is no
additional padding.
Or to quickly test-run a small model:
```bash
python main.py --model_type=graph2text \
--dataset=freebase2wikitext \
--checkpoint_dir=/tmp/graph2text \
--job_mode=train \
--train_batch_size=2 \
--gnn_num_layers=1
```
We ran our experiments in the paper using 8 Nvidia V100 GPUs. Reduce the batch
size if the model does not fit into memory. To allow for batch parallelization
for the GNN-based (graph2text) model, we pad graphs to the largest graph in the
batch. The full run takes almost 4 days. BoW- and nodes-based models can be
trained within 14 hours because there is no additional padding.
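To illustrate what padding to the largest graph in the batch means, here is a minimal sketch in plain Python. It is not the code used in this package: the dictionary layout, the `pad_graphs` helper, and the `PAD_NODE` / `PAD_EDGE` placeholders are all hypothetical and chosen only for illustration.

```python
# Hypothetical placeholders for padded entries; the real model may use others.
PAD_NODE = 0
PAD_EDGE = (0, 0)


def pad_graphs(batch):
    """Pad every graph to the node and edge counts of the largest graph in the batch.

    Each graph is a dict with a list of node ids under "nodes" and a list of
    (sender, receiver) pairs under "edges". Masks mark real entries with 1 and
    padded entries with 0.
    """
    max_nodes = max(len(g["nodes"]) for g in batch)
    max_edges = max(len(g["edges"]) for g in batch)
    padded = []
    for g in batch:
        n, e = len(g["nodes"]), len(g["edges"])
        padded.append({
            "nodes": g["nodes"] + [PAD_NODE] * (max_nodes - n),
            "edges": g["edges"] + [PAD_EDGE] * (max_edges - e),
            "node_mask": [1] * n + [0] * (max_nodes - n),
            "edge_mask": [1] * e + [0] * (max_edges - e),
        })
    return padded


if __name__ == "__main__":
    batch = [
        {"nodes": [7, 8], "edges": [(0, 1)]},
        {"nodes": [1, 2, 3, 4], "edges": [(0, 1), (1, 2), (2, 3)]},
    ]
    for g in pad_graphs(batch):
        print(len(g["nodes"]), len(g["edges"]), g["node_mask"], g["edge_mask"])
```

After padding, every graph in the batch has the same shape, so the batch can be split evenly across devices. The cost is that small graphs carry padded entries that still take up memory and compute, which is consistent with the longer training time reported for the graph2text model.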
To evaluate the model on the validation set (this only uses 1 GPU):
+1 -1
@@ -33,7 +33,7 @@ from setuptools import setup
setup(
name='wikigraphs',
version='0.0.2',
version='0.1.0',
description='A Wikipedia - knowledge graph paired dataset.',
url='https://github.com/deepmind/deepmind-research/tree/master/wikigraphs',
author='DeepMind',