mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Enable Travis tests for geomancer
* Added a flag to disable plotting when running headless on Travis PiperOrigin-RevId: 368229292
This commit is contained in:
committed by
Diego de Las Casas
parent
d4a9a684cb
commit
20e06f144a
@@ -12,6 +12,7 @@ env:
|
||||
- PROJECT="avae"
|
||||
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
||||
- PROJECT="gated_linear_networks"
|
||||
- PROJECT="geomancer"
|
||||
- PROJECT="iodine"
|
||||
- PROJECT="kfac_ferminet_alpha"
|
||||
- PROJECT="learning_to_simulate"
|
||||
|
||||
Regular → Executable
+6
-6
@@ -13,9 +13,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
python3 -m venv geomancer-venv
|
||||
source geomancer-venv/bin/activate
|
||||
pip3 install .
|
||||
python3 geomancer_test.py
|
||||
python3 train.py
|
||||
deactivate
|
||||
python3 -m venv /tmp/geomancer-venv
|
||||
source /tmp/geomancer-venv/bin/activate
|
||||
pip3 install -U pip
|
||||
pip3 install geomancer/
|
||||
python3 -m geomancer.geomancer_test
|
||||
python3 geomancer/train.py --plot=False
|
||||
|
||||
+50
-44
@@ -28,11 +28,14 @@ import numpy as np
|
||||
from scipy.stats import special_ortho_group
|
||||
from tqdm import tqdm
|
||||
|
||||
flags.DEFINE_list('specification', ['S^2', 'S^2'], 'List of submanifolds')
|
||||
flags.DEFINE_integer('npts', 1000, 'Number of data points')
|
||||
flags.DEFINE_boolean('rotate', False, 'Apply random rotation to the data')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
SPECIFICATION = flags.DEFINE_list(
|
||||
name='specification', default=['S^2', 'S^2'], help='List of submanifolds')
|
||||
NPTS = flags.DEFINE_integer(
|
||||
name='npts', default=1000, help='Number of data points')
|
||||
ROTATE = flags.DEFINE_boolean(
|
||||
name='rotate', default=False, help='Apply random rotation to the data')
|
||||
PLOT = flags.DEFINE_boolean(
|
||||
name='plot', default=True, help='Whether to enable plotting')
|
||||
|
||||
|
||||
def make_so_tangent(q):
|
||||
@@ -139,8 +142,8 @@ def make_product_manifold(specification, npts):
|
||||
|
||||
def main(_):
|
||||
# Generate data and run GEOMANCER
|
||||
data, dim, tangents = make_product_manifold(FLAGS.specification, FLAGS.npts)
|
||||
if FLAGS.rotate:
|
||||
data, dim, tangents = make_product_manifold(SPECIFICATION.value, NPTS.value)
|
||||
if ROTATE.value:
|
||||
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
|
||||
data_rot = data @ rot.T
|
||||
components, spectrum = geomancer.fit(data_rot, dim)
|
||||
@@ -149,46 +152,49 @@ def main(_):
|
||||
components, spectrum = geomancer.fit(data, dim)
|
||||
errors = geomancer.eval_aligned(components, tangents)
|
||||
|
||||
# Plot spectrum
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
|
||||
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
|
||||
plt.axvline(largest_gap, linewidth=2, c='r')
|
||||
plt.xticks([])
|
||||
plt.yticks(fontsize=18)
|
||||
plt.xlabel('Index', fontsize=24)
|
||||
plt.ylabel('Eigenvalue', fontsize=24)
|
||||
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
|
||||
|
||||
# Plot subspace bases
|
||||
fig = plt.figure(figsize=(8, 6))
|
||||
bases = components[0]
|
||||
gs = gridspec.GridSpec(1, len(bases),
|
||||
width_ratios=[b.shape[1] for b in bases])
|
||||
for i in range(len(bases)):
|
||||
ax = plt.subplot(gs[i])
|
||||
ax.imshow(bases[i])
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
|
||||
fig.canvas.set_window_title('GeoManCEr Results')
|
||||
|
||||
# Plot ground truth
|
||||
fig = plt.figure(figsize=(8, 6))
|
||||
gs = gridspec.GridSpec(1, len(tangents),
|
||||
width_ratios=[b.shape[2] for b in tangents])
|
||||
for i, spec in enumerate(FLAGS.specification):
|
||||
ax = plt.subplot(gs[i])
|
||||
ax.imshow(tangents[i][0])
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
|
||||
fig.canvas.set_window_title('Ground Truth')
|
||||
|
||||
logging.info('Error between subspaces: %.2f +/- %.2f radians',
|
||||
np.mean(errors),
|
||||
np.std(errors))
|
||||
plt.show()
|
||||
|
||||
if PLOT.value:
|
||||
|
||||
# Plot spectrum
|
||||
plt.figure(figsize=(8, 6))
|
||||
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
|
||||
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
|
||||
plt.axvline(largest_gap, linewidth=2, c='r')
|
||||
plt.xticks([])
|
||||
plt.yticks(fontsize=18)
|
||||
plt.xlabel('Index', fontsize=24)
|
||||
plt.ylabel('Eigenvalue', fontsize=24)
|
||||
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
|
||||
|
||||
# Plot subspace bases
|
||||
fig = plt.figure(figsize=(8, 6))
|
||||
bases = components[0]
|
||||
gs = gridspec.GridSpec(1, len(bases),
|
||||
width_ratios=[b.shape[1] for b in bases])
|
||||
for i in range(len(bases)):
|
||||
ax = plt.subplot(gs[i])
|
||||
ax.imshow(bases[i])
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
|
||||
fig.canvas.set_window_title('GeoManCEr Results')
|
||||
|
||||
# Plot ground truth
|
||||
fig = plt.figure(figsize=(8, 6))
|
||||
gs = gridspec.GridSpec(1, len(tangents),
|
||||
width_ratios=[b.shape[2] for b in tangents])
|
||||
for i, spec in enumerate(SPECIFICATION.value):
|
||||
ax = plt.subplot(gs[i])
|
||||
ax.imshow(tangents[i][0])
|
||||
ax.set_xticks([])
|
||||
ax.set_yticks([])
|
||||
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
|
||||
fig.canvas.set_window_title('Ground Truth')
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user