Files
deepmind-research/wikigraphs/scripts/visualize_graph.py
T
Sven Gowal 5909da5388 Added jaxline pipeline to train adversarially robust models.
PiperOrigin-RevId: 383399487
2021-07-14 15:07:21 +00:00

144 lines
5.1 KiB
Python

# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# WikiGraphs is licensed under the terms of the Creative Commons
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0) license.
#
# WikiText-103 data (unchanged) is licensed by Salesforce.com, Inc. under the
# terms of the Creative Commons Attribution-ShareAlike 4.0 International
# (CC BY-SA 4.0) license. You can find details about CC BY-SA 4.0 at:
#
# https://creativecommons.org/licenses/by-sa/4.0/legalcode
#
# Freebase data is licensed by Google LLC under the terms of the Creative
# Commons CC BY 4.0 license. You may obtain a copy of the License at:
#
# https://creativecommons.org/licenses/by/4.0/legalcode
#
# ==============================================================================
r"""Tool to visualize graphs.
You need to have the command line tool `dot` installed locally, for example by
`sudo apt-get install graphviz`.
Example usage:
python visualize_graph.py \
--logtostderr --graph_ids=0:48 --truncate_limit=500 --layout=fdp
"""
import html
import os
import subprocess
import textwrap

from absl import app
from absl import flags
from absl import logging

from wikigraphs.data import io_tools
from wikigraphs.data import paired_dataset as pd
# Command-line flags. The parsed values are read through FLAGS by truncate(),
# visualize_graph() and main().
FLAGS = flags.FLAGS
# Forwarded to pd.RawDataset as `subset` in main().
flags.DEFINE_string('subset', 'valid', 'Which subset to choose graphs from.')
# Parsed in main(): a value containing ':' is read as the half-open range
# `start:end`, anything else as a comma-separated list of ints. The empty
# default is not a usable value; main() fails on it.
flags.DEFINE_string('graph_ids', '', 'A comma-separated string of graph IDs'
' (0-based), for example `1,2,3`. Or alternatively a'
' range, e.g. `0:10` which is equivalent to'
' `0,1,2,3,...,9`.')
# Forwarded to pd.RawDataset as `version` in main().
flags.DEFINE_string('version', 'max256', 'Which version of data to load.')
# main() converts the empty default to None before passing it to
# pd.RawDataset.
flags.DEFINE_string('data_dir', '', 'Path to a directory that contains the raw'
' paired data, if provided.')
# Created by main() when missing; receives `<graph_id>.dot` and
# `<graph_id>.pdf` for every visualized graph.
flags.DEFINE_string('output_dir', '/tmp/graph_vis', 'Output directory to save'
' the visualized graphs.')
# Only values > 0 enable truncation (see truncate()); the default -1 turns it
# off. Because format_label() is used for edges too, the limit also applies to
# edge labels.
flags.DEFINE_integer('truncate_limit', -1, 'Maximum length for graph nodes in'
' visualization.')
# Passed to the `dot` command line as `-K<layout>` in visualize_graph().
flags.DEFINE_string('layout', 'fdp', 'Which one of the dot layout to use.')
def truncate(s: str) -> str:
  """Shorten `s` according to the `truncate_limit` flag.

  Args:
    s: the text to shorten.

  Returns:
    `s` unchanged when the flag is not positive or `s` has at most
    `truncate_limit` characters; otherwise the first `truncate_limit`
    characters of `s` followed by '...'.
  """
  limit = FLAGS.truncate_limit
  # A non-positive limit means truncation is switched off.
  if limit <= 0:
    return s
  if len(s) <= limit:
    return s
  return s[:limit] + '...'
def format_label(s: str, width: int = 40) -> str:
  """Format a node / edge label for use as a dot HTML-like label.

  The text is passed through `io_tools.normalize_freebase_string`, shortened
  by `truncate`, split on the literal two-character sequence backslash + 'n',
  and every resulting line is optionally word-wrapped and then HTML-escaped.

  Args:
    s: raw label text.
    width: maximum number of visible characters per output line. Values <= 0
      disable wrapping.

  Returns:
    The formatted lines joined with `<br/>` and enclosed in `<` ... `>`.
  """
  s = io_tools.normalize_freebase_string(s)
  s = truncate(s)
  output_lines = []
  # '\\n' is a backslash followed by 'n' as it appears in the data, not a
  # newline character.
  for line in s.split('\\n'):
    # Wrap the raw text first and escape afterwards. Wrapping text that is
    # already escaped counts entity characters (e.g. `&amp;`) against the
    # width and lets textwrap cut a long word in the middle of an entity,
    # which yields a malformed label.
    pieces = textwrap.wrap(line, width) if width > 0 else [line]
    output_lines.extend(html.escape(piece) for piece in pieces)
  return '<' + '<br/>'.join(output_lines) + '>'
def graph_to_dot(graph_text_pair: io_tools.GraphTextPair) -> str:
  """Convert a graph to the text of a dot file.

  Nodes are emitted as filled rectangles named by their position in
  `graph.nodes()`. The fill color is chosen as follows:
    * '#f5dc98' for the node whose position equals the id returned by
      `graph.node2id` for the pair's center node,
    * '#ffffff' for nodes whose text starts and ends with a double quote,
    * '#b0ffad' for all other nodes.
  Edge labels are formatted without line wrapping.

  Args:
    graph_text_pair: the pair whose `edges` and `center_node` are visualized.

  Returns:
    The dot source of a `digraph`, one statement per line.
  """
  graph = pd.Graph.from_edges(graph_text_pair.edges)
  # NOTE(review): assumes node2id uses the same numbering as the enumeration
  # order of graph.nodes() -- confirm against paired_dataset.Graph.
  center_node_id = graph.node2id(graph_text_pair.center_node)

  dot = ['digraph {', 'node [shape=rect];']
  for i, n in enumerate(graph.nodes()):
    if i == center_node_id:
      color = '#f5dc98'
    elif n.startswith('"') and n.endswith('"'):
      # startswith/endswith instead of n[0]/n[-1] so that an empty node
      # string does not raise IndexError.
      color = '#ffffff'
    else:
      color = '#b0ffad'
    label = format_label(n)
    dot.append(f'{i} [ label = {label}, fillcolor="{color}", style="filled"];')
  for i, j, e in graph.edges():
    dot.append(f'{i} -> {j} [ label = {format_label(e, width=0)} ];')
  dot.append('}')
  return '\n'.join(dot)
def visualize_graph(graph_text_pair: io_tools.GraphTextPair,
                    graph_id: int,
                    output_dir: str):
  """Visualize a graph and save the visualization to the specified directory.

  Writes `<graph_id>.dot` into `output_dir` and then runs the `dot` command
  line tool, with the layout engine given by the `layout` flag, to render
  `<graph_id>.pdf` next to it. Rendering problems are logged instead of
  raised, so one failing graph does not stop the remaining ones.

  Args:
    graph_text_pair: the graph-text pair to visualize.
    graph_id: ID of the graph, used as the base name of the output files.
    output_dir: existing directory to write the output files to.
  """
  dot = graph_to_dot(graph_text_pair)
  output_file = os.path.join(output_dir, f'{graph_id}.dot')
  logging.info('Writing output to %s', output_file)
  # Labels may contain non-ASCII text; write UTF-8 explicitly instead of
  # depending on the platform default encoding.
  with open(output_file, 'w', encoding='utf-8') as f:
    f.write(dot)

  pdf_output = os.path.join(output_dir, f'{graph_id}.pdf')
  # Pass an argument list instead of a shell string so that paths with spaces
  # or shell metacharacters in the flag values cannot break or alter the
  # command.
  cmd = ['dot', f'-K{FLAGS.layout}', '-Tpdf', '-o', pdf_output, output_file]
  try:
    result = subprocess.run(cmd, check=False)
  except OSError as e:
    # Raised e.g. when the `dot` executable cannot be found.
    logging.error('Could not run `dot` (is graphviz installed?): %s', e)
    return
  if result.returncode != 0:
    logging.error('`dot` exited with status %d while rendering %s',
                  result.returncode, output_file)
def main(_):
  """Load the dataset and visualize the graphs selected by `--graph_ids`.

  Raises:
    app.UsageError: if `--graph_ids` is empty.
    ValueError: if `--graph_ids` cannot be parsed as integers.
  """
  # Fail early with a readable message; the flag's default is the empty
  # string, which would otherwise surface as an `int('')` ValueError only
  # after the whole dataset has been loaded.
  if not FLAGS.graph_ids.strip():
    raise app.UsageError(
        '--graph_ids must be set, e.g. --graph_ids=1,2,3 or --graph_ids=0:10')

  logging.info('Loading the %s set of data.', FLAGS.subset)
  pairs = list(pd.RawDataset(subset=FLAGS.subset,
                             data_dir=FLAGS.data_dir or None,
                             shuffle_data=False,
                             version=FLAGS.version))
  # The count was previously missing, so the literal '%d' was logged.
  logging.info('Loaded %d graph-text pairs.', len(pairs))

  if ':' in FLAGS.graph_ids:
    # `start:end` is a half-open range, like Python slicing.
    start, end = [int(i) for i in FLAGS.graph_ids.split(':')]
    graph_ids = list(range(start, end))
  else:
    graph_ids = [int(i) for i in FLAGS.graph_ids.split(',')]
  logging.info('Visualizing graphs with ID %r', graph_ids)

  # exist_ok avoids the race between checking for and creating the directory.
  os.makedirs(FLAGS.output_dir, exist_ok=True)

  for gid in graph_ids:
    visualize_graph(pairs[gid], gid, FLAGS.output_dir)


if __name__ == '__main__':
  app.run(main)