mirror of
https://github.com/google-deepmind/deepmind-research.git
synced 2026-05-09 21:07:49 +08:00
Enable Travis tests for geomancer
* Added a flag to disable plotting when running headless on Travis PiperOrigin-RevId: 368229292
This commit is contained in:
committed by
Diego de Las Casas
parent
d4a9a684cb
commit
20e06f144a
@@ -12,6 +12,7 @@ env:
|
|||||||
- PROJECT="avae"
|
- PROJECT="avae"
|
||||||
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
# - PROJECT="cs_gan" # TODO(b/184845450): Fix and re-enable
|
||||||
- PROJECT="gated_linear_networks"
|
- PROJECT="gated_linear_networks"
|
||||||
|
- PROJECT="geomancer"
|
||||||
- PROJECT="iodine"
|
- PROJECT="iodine"
|
||||||
- PROJECT="kfac_ferminet_alpha"
|
- PROJECT="kfac_ferminet_alpha"
|
||||||
- PROJECT="learning_to_simulate"
|
- PROJECT="learning_to_simulate"
|
||||||
|
|||||||
Regular → Executable
+6
-6
@@ -13,9 +13,9 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
python3 -m venv geomancer-venv
|
python3 -m venv /tmp/geomancer-venv
|
||||||
source geomancer-venv/bin/activate
|
source /tmp/geomancer-venv/bin/activate
|
||||||
pip3 install .
|
pip3 install -U pip
|
||||||
python3 geomancer_test.py
|
pip3 install geomancer/
|
||||||
python3 train.py
|
python3 -m geomancer.geomancer_test
|
||||||
deactivate
|
python3 geomancer/train.py --plot=False
|
||||||
|
|||||||
+50
-44
@@ -28,11 +28,14 @@ import numpy as np
|
|||||||
from scipy.stats import special_ortho_group
|
from scipy.stats import special_ortho_group
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
flags.DEFINE_list('specification', ['S^2', 'S^2'], 'List of submanifolds')
|
SPECIFICATION = flags.DEFINE_list(
|
||||||
flags.DEFINE_integer('npts', 1000, 'Number of data points')
|
name='specification', default=['S^2', 'S^2'], help='List of submanifolds')
|
||||||
flags.DEFINE_boolean('rotate', False, 'Apply random rotation to the data')
|
NPTS = flags.DEFINE_integer(
|
||||||
|
name='npts', default=1000, help='Number of data points')
|
||||||
FLAGS = flags.FLAGS
|
ROTATE = flags.DEFINE_boolean(
|
||||||
|
name='rotate', default=False, help='Apply random rotation to the data')
|
||||||
|
PLOT = flags.DEFINE_boolean(
|
||||||
|
name='plot', default=True, help='Whether to enable plotting')
|
||||||
|
|
||||||
|
|
||||||
def make_so_tangent(q):
|
def make_so_tangent(q):
|
||||||
@@ -139,8 +142,8 @@ def make_product_manifold(specification, npts):
|
|||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
# Generate data and run GEOMANCER
|
# Generate data and run GEOMANCER
|
||||||
data, dim, tangents = make_product_manifold(FLAGS.specification, FLAGS.npts)
|
data, dim, tangents = make_product_manifold(SPECIFICATION.value, NPTS.value)
|
||||||
if FLAGS.rotate:
|
if ROTATE.value:
|
||||||
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
|
rot, _ = np.linalg.qr(np.random.randn(data.shape[1], data.shape[1]))
|
||||||
data_rot = data @ rot.T
|
data_rot = data @ rot.T
|
||||||
components, spectrum = geomancer.fit(data_rot, dim)
|
components, spectrum = geomancer.fit(data_rot, dim)
|
||||||
@@ -149,46 +152,49 @@ def main(_):
|
|||||||
components, spectrum = geomancer.fit(data, dim)
|
components, spectrum = geomancer.fit(data, dim)
|
||||||
errors = geomancer.eval_aligned(components, tangents)
|
errors = geomancer.eval_aligned(components, tangents)
|
||||||
|
|
||||||
# Plot spectrum
|
|
||||||
plt.figure(figsize=(8, 6))
|
|
||||||
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
|
|
||||||
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
|
|
||||||
plt.axvline(largest_gap, linewidth=2, c='r')
|
|
||||||
plt.xticks([])
|
|
||||||
plt.yticks(fontsize=18)
|
|
||||||
plt.xlabel('Index', fontsize=24)
|
|
||||||
plt.ylabel('Eigenvalue', fontsize=24)
|
|
||||||
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
|
|
||||||
|
|
||||||
# Plot subspace bases
|
|
||||||
fig = plt.figure(figsize=(8, 6))
|
|
||||||
bases = components[0]
|
|
||||||
gs = gridspec.GridSpec(1, len(bases),
|
|
||||||
width_ratios=[b.shape[1] for b in bases])
|
|
||||||
for i in range(len(bases)):
|
|
||||||
ax = plt.subplot(gs[i])
|
|
||||||
ax.imshow(bases[i])
|
|
||||||
ax.set_xticks([])
|
|
||||||
ax.set_yticks([])
|
|
||||||
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
|
|
||||||
fig.canvas.set_window_title('GeoManCEr Results')
|
|
||||||
|
|
||||||
# Plot ground truth
|
|
||||||
fig = plt.figure(figsize=(8, 6))
|
|
||||||
gs = gridspec.GridSpec(1, len(tangents),
|
|
||||||
width_ratios=[b.shape[2] for b in tangents])
|
|
||||||
for i, spec in enumerate(FLAGS.specification):
|
|
||||||
ax = plt.subplot(gs[i])
|
|
||||||
ax.imshow(tangents[i][0])
|
|
||||||
ax.set_xticks([])
|
|
||||||
ax.set_yticks([])
|
|
||||||
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
|
|
||||||
fig.canvas.set_window_title('Ground Truth')
|
|
||||||
|
|
||||||
logging.info('Error between subspaces: %.2f +/- %.2f radians',
|
logging.info('Error between subspaces: %.2f +/- %.2f radians',
|
||||||
np.mean(errors),
|
np.mean(errors),
|
||||||
np.std(errors))
|
np.std(errors))
|
||||||
plt.show()
|
|
||||||
|
if PLOT.value:
|
||||||
|
|
||||||
|
# Plot spectrum
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
plt.scatter(np.arange(len(spectrum)), spectrum, s=100)
|
||||||
|
largest_gap = np.argmax(spectrum[1:]-spectrum[:-1]) + 1
|
||||||
|
plt.axvline(largest_gap, linewidth=2, c='r')
|
||||||
|
plt.xticks([])
|
||||||
|
plt.yticks(fontsize=18)
|
||||||
|
plt.xlabel('Index', fontsize=24)
|
||||||
|
plt.ylabel('Eigenvalue', fontsize=24)
|
||||||
|
plt.title('GeoManCEr Eigenvalue Spectrum', fontsize=24)
|
||||||
|
|
||||||
|
# Plot subspace bases
|
||||||
|
fig = plt.figure(figsize=(8, 6))
|
||||||
|
bases = components[0]
|
||||||
|
gs = gridspec.GridSpec(1, len(bases),
|
||||||
|
width_ratios=[b.shape[1] for b in bases])
|
||||||
|
for i in range(len(bases)):
|
||||||
|
ax = plt.subplot(gs[i])
|
||||||
|
ax.imshow(bases[i])
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.set_title(r'$T_{\mathbf{x}_1}\mathcal{M}_%d$' % (i+1), fontsize=18)
|
||||||
|
fig.canvas.set_window_title('GeoManCEr Results')
|
||||||
|
|
||||||
|
# Plot ground truth
|
||||||
|
fig = plt.figure(figsize=(8, 6))
|
||||||
|
gs = gridspec.GridSpec(1, len(tangents),
|
||||||
|
width_ratios=[b.shape[2] for b in tangents])
|
||||||
|
for i, spec in enumerate(SPECIFICATION.value):
|
||||||
|
ax = plt.subplot(gs[i])
|
||||||
|
ax.imshow(tangents[i][0])
|
||||||
|
ax.set_xticks([])
|
||||||
|
ax.set_yticks([])
|
||||||
|
ax.set_title(r'$T_{\mathbf{x}_1}%s$' % spec, fontsize=18)
|
||||||
|
fig.canvas.set_window_title('Ground Truth')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user