Enable Travis tests for geomancer

* Added a flag to disable plotting when running headless on Travis

PiperOrigin-RevId: 368229292
This commit is contained in:
Alistair Muldal
2021-04-13 17:21:03 +01:00
committed by Diego de Las Casas
parent d4a9a684cb
commit 20e06f144a
3 changed files with 57 additions and 50 deletions
+1
View File
@@ -12,6 +12,7 @@ env:
- PROJECT="avae" - PROJECT="avae"
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable # - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
- PROJECT="gated_linear_networks" - PROJECT="gated_linear_networks"
- PROJECT="geomancer"
- PROJECT="iodine" - PROJECT="iodine"
- PROJECT="kfac_ferminet_alpha" - PROJECT="kfac_ferminet_alpha"
- PROJECT="learning_to_simulate" - PROJECT="learning_to_simulate"
Regular → Executable
+6 -6
View File
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
python3 -m venv geomancer-venv python3 -m venv /tmp/geomancer-venv
source geomancer-venv/bin/activate source /tmp/geomancer-venv/bin/activate
pip3 install . pip3 install -U pip
python3 geomancer_test.py pip3 install geomancer/
python3 train.py python3 -m geomancer.geomancer_test
deactivate python3 geomancer/train.py --plot=False
+50 -44
View File
@@ -28,11 +28,14 @@ import numpy as np
from scipy.stats import special_ortho_group from scipy.stats import special_ortho_group
from tqdm import tqdm from tqdm import tqdm
flags.DEFINE_list('specification', ['S^2', 'S^2'], 'List of submanifolds') SPECIFICATION = flags.DEFINE_list(
flags.DEFINE_integer('npts', 1000, 'Number of data points') name='specification', default=['S^2', 'S^2'], help='List of submanifolds')
flags.DEFINE_boolean('rotate', False, 'Apply random rotation to the data') NPTS = flags.DEFINE_integer(
name='npts', default=1000, help='Number of data points')
FLAGS = flags.FLAGS ROTATE = flags.DEFINE_boolean(
name='rotate', default=False, help='Apply random rotation to the data')
PLOT = flags.DEFINE_boolean(
name='plot', default=True, help='Whether to enable plotting')
def make_so_tangent(q): def make_so_tangent(q):
@@ -139,8 +142,8 @@ def make_product_manifold(specification, npts):
def main(_): def main(_):
# Generate data and run GEOMANCER # Generate data and run GEOMANCER
data, dim, tangents = make_product_manifold(FLAGS.specification, FLAGS.npts) data, dim, tangents = make_product_manifold(SPECIFICATION.value, NPTS.value)
if FLAGS.rotate: if ROTATE.value:
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1])) rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
data_rot = data @ rot.T data_rot = data @ rot.T
components, spectrum = geomancer.fit(data_rot, dim) components, spectrum = geomancer.fit(data_rot, dim)
@@ -149,46 +152,49 @@ def main(_):
components, spectrum = geomancer.fit(data, dim) components, spectrum = geomancer.fit(data, dim)
errors = geomancer.eval_aligned(components, tangents) errors = geomancer.eval_aligned(components, tangents)
# Plot spectrum
plt.figure(figsize=(8, 6))
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
plt.axvline(largest_gap, linewidth=2, c='r')
plt.xticks([])
plt.yticks(fontsize=18)
plt.xlabel('Index', fontsize=24)
plt.ylabel('Eigenvalue', fontsize=24)
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
# Plot subspace bases
fig = plt.figure(figsize=(8, 6))
bases = components[0]
gs = gridspec.GridSpec(1, len(bases),
width_ratios=[b.shape[1] for b in bases])
for i in range(len(bases)):
ax = plt.subplot(gs[i])
ax.imshow(bases[i])
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
fig.canvas.set_window_title('GeoManCEr Results')
# Plot ground truth
fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(1, len(tangents),
width_ratios=[b.shape[2] for b in tangents])
for i, spec in enumerate(FLAGS.specification):
ax = plt.subplot(gs[i])
ax.imshow(tangents[i][0])
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
fig.canvas.set_window_title('Ground Truth')
logging.info('Error between subspaces: %.2f +/- %.2f radians', logging.info('Error between subspaces: %.2f +/- %.2f radians',
np.mean(errors), np.mean(errors),
np.std(errors)) np.std(errors))
plt.show()
if PLOT.value:
# Plot spectrum
plt.figure(figsize=(8, 6))
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
plt.axvline(largest_gap, linewidth=2, c='r')
plt.xticks([])
plt.yticks(fontsize=18)
plt.xlabel('Index', fontsize=24)
plt.ylabel('Eigenvalue', fontsize=24)
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
# Plot subspace bases
fig = plt.figure(figsize=(8, 6))
bases = components[0]
gs = gridspec.GridSpec(1, len(bases),
width_ratios=[b.shape[1] for b in bases])
for i in range(len(bases)):
ax = plt.subplot(gs[i])
ax.imshow(bases[i])
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
fig.canvas.set_window_title('GeoManCEr Results')
# Plot ground truth
fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(1, len(tangents),
width_ratios=[b.shape[2] for b in tangents])
for i, spec in enumerate(SPECIFICATION.value):
ax = plt.subplot(gs[i])
ax.imshow(tangents[i][0])
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
fig.canvas.set_window_title('Ground Truth')
plt.show()
if __name__ == '__main__': if __name__ == '__main__':