From 20e06f144ac0b5978f75f0214df17d042112a755 Mon Sep 17 00:00:00 2001 From: Alistair Muldal Date: Tue, 13 Apr 2021 17:21:03 +0100 Subject: [PATCH] Enable Travis tests for `geomancer` * Added a flag to disable plotting when running headless on Travis PiperOrigin-RevId: 368229292 --- .travis.yml | 1 + geomancer/run.sh | 12 +++--- geomancer/train.py | 94 ++++++++++++++++++++++++---------------------- 3 files changed, 57 insertions(+), 50 deletions(-) mode change 100644 => 100755 geomancer/run.sh diff --git a/.travis.yml b/.travis.yml index d161e2d..41bc85f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,6 +12,7 @@ env: - PROJECT="avae" # - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable - PROJECT="gated_linear_networks" + - PROJECT="geomancer" - PROJECT="iodine" - PROJECT="kfac_ferminet_alpha" - PROJECT="learning_to_simulate" diff --git a/geomancer/run.sh b/geomancer/run.sh old mode 100644 new mode 100755 index 557fdb7..7bc614c --- a/geomancer/run.sh +++ b/geomancer/run.sh @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -python3 -m venv geomancer-venv -source geomancer-venv/bin/activate -pip3 install . -python3 geomancer_test.py -python3 train.py -deactivate +python3 -m venv /tmp/geomancer-venv +source /tmp/geomancer-venv/bin/activate +pip3 install -U pip +pip3 install geomancer/ +python3 -m geomancer.geomancer_test +python3 geomancer/train.py --plot=False diff --git a/geomancer/train.py b/geomancer/train.py index 2ebd541..e72a817 100644 --- a/geomancer/train.py +++ b/geomancer/train.py @@ -28,11 +28,14 @@ import numpy as np from scipy.stats import special_ortho_group from tqdm import tqdm -flags.DEFINE_list('specification', ['S^2', 'S^2'], 'List of submanifolds') -flags.DEFINE_integer('npts', 1000, 'Number of data points') -flags.DEFINE_boolean('rotate', False, 'Apply random rotation to the data') - -FLAGS = flags.FLAGS +SPECIFICATION = flags.DEFINE_list( + name='specification', default=['S^2', 'S^2'], help='List of submanifolds') +NPTS = flags.DEFINE_integer( + name='npts', default=1000, help='Number of data points') +ROTATE = flags.DEFINE_boolean( + name='rotate', default=False, help='Apply random rotation to the data') +PLOT = flags.DEFINE_boolean( + name='plot', default=True, help='Whether to enable plotting') def make_so_tangent(q): @@ -139,8 +142,8 @@ def make_product_manifold(specification, npts): def main(_): # Generate data and run GEOMANCER - data, dim, tangents = make_product_manifold(FLAGS.specification, FLAGS.npts) - if FLAGS.rotate: + data, dim, tangents = make_product_manifold(SPECIFICATION.value, NPTS.value) + if ROTATE.value: rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1])) data_rot = data @ rot.T components, spectrum = geomancer.fit(data_rot, dim) @@ -149,46 +152,49 @@ def main(_): components, spectrum = geomancer.fit(data, dim) errors = geomancer.eval_aligned(components, tangents) - # Plot spectrum - plt.figure(figsize=(8, 6)) - plt.scatter(np.arange(len(spectrum)), spectrum, s=100) - largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1 - plt.axvline(largest_gap, linewidth=2, c='r') - plt.xticks([]) - plt.yticks(fontsize=18) - plt.xlabel('Index', fontsize=24) - plt.ylabel('Eigenvalue', fontsize=24) - plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24) - - # Plot subspace bases - fig = plt.figure(figsize=(8, 6)) - bases = components[0] - gs = gridspec.GridSpec(1, len(bases), - width_ratios=[b.shape[1] for b in bases]) - for i in range(len(bases)): - ax = plt.subplot(gs[i]) - ax.imshow(bases[i]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18) - fig.canvas.set_window_title('GeoManCEr Results') - - # Plot ground truth - fig = plt.figure(figsize=(8, 6)) - gs = gridspec.GridSpec(1, len(tangents), - width_ratios=[b.shape[2] for b in tangents]) - for i, spec in enumerate(FLAGS.specification): - ax = plt.subplot(gs[i]) - ax.imshow(tangents[i][0]) - ax.set_xticks([]) - ax.set_yticks([]) - ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18) - fig.canvas.set_window_title('Ground Truth') - logging.info('Error between subspaces: %.2f +/- %.2f radians', np.mean(errors), np.std(errors)) - plt.show() + + if PLOT.value: + + # Plot spectrum + plt.figure(figsize=(8, 6)) + plt.scatter(np.arange(len(spectrum)), spectrum, s=100) + largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1 + plt.axvline(largest_gap, linewidth=2, c='r') + plt.xticks([]) + plt.yticks(fontsize=18) + plt.xlabel('Index', fontsize=24) + plt.ylabel('Eigenvalue', fontsize=24) + plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24) + + # Plot subspace bases + fig = plt.figure(figsize=(8, 6)) + bases = components[0] + gs = gridspec.GridSpec(1, len(bases), + width_ratios=[b.shape[1] for b in bases]) + for i in range(len(bases)): + ax = plt.subplot(gs[i]) + ax.imshow(bases[i]) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18) + fig.canvas.set_window_title('GeoManCEr Results') + + # Plot ground truth + fig = plt.figure(figsize=(8, 6)) + gs = gridspec.GridSpec(1, len(tangents), + width_ratios=[b.shape[2] for b in tangents]) + for i, spec in enumerate(SPECIFICATION.value): + ax = plt.subplot(gs[i]) + ax.imshow(tangents[i][0]) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18) + fig.canvas.set_window_title('Ground Truth') + + plt.show() if __name__ == '__main__':